diff --git a/.ci/scripts/select_sklearn_tests.py b/.ci/scripts/select_sklearn_tests.py
index c69d20e413..868e74cee3 100644
--- a/.ci/scripts/select_sklearn_tests.py
+++ b/.ci/scripts/select_sklearn_tests.py
@@ -52,6 +52,7 @@ def parse_tests_tree(entry, prefix=""):
     "model_selection/tests": ["test_split.py", "test_validation.py"],
     "neighbors/tests": ["test_lof.py", "test_neighbors.py", "test_neighbors_pipeline.py"],
     "svm/tests": ["test_sparse.py", "test_svm.py"],
+    "tests": ["test_dummy.py"],
 }
 if sklearn_check_version("1.2"):
     tests_map["tests"] = ["test_public_functions.py"]
diff --git a/onedal/__init__.py b/onedal/__init__.py
index 7409f7bc6b..efa7076f48 100644
--- a/onedal/__init__.py
+++ b/onedal/__init__.py
@@ -122,6 +122,7 @@ def __repr__(self) -> str:
     "_spmd_backend",
     "covariance",
     "decomposition",
+    "dummy",
     "ensemble",
     "neighbors",
     "primitives",
diff --git a/onedal/dal.cpp b/onedal/dal.cpp
index d0f25fb62a..80917c60b6 100644
--- a/onedal/dal.cpp
+++ b/onedal/dal.cpp
@@ -78,6 +78,7 @@ ONEDAL_PY_INIT_MODULE(logistic_regression);
 #if defined(ONEDAL_VERSION) && ONEDAL_VERSION >= 20240700
 ONEDAL_PY_INIT_MODULE(finiteness_checker);
 #endif // defined(ONEDAL_VERSION) && ONEDAL_VERSION >= 20240700
+ONEDAL_PY_INIT_MODULE(dummy);
 #endif // ONEDAL_DATA_PARALLEL_SPMD
 
 #ifdef ONEDAL_DATA_PARALLEL_SPMD
@@ -138,6 +139,7 @@ PYBIND11_MODULE(_onedal_py_host, m) {
 #if defined(ONEDAL_VERSION) && ONEDAL_VERSION >= 20240700
     init_finiteness_checker(m);
 #endif // defined(ONEDAL_VERSION) && ONEDAL_VERSION >= 20240700
+    init_dummy(m);
 }
 #endif // ONEDAL_DATA_PARALLEL_SPMD
diff --git a/onedal/dummy/__init__.py b/onedal/dummy/__init__.py
new file mode 100644
index 0000000000..d92fb51813
--- /dev/null
+++ b/onedal/dummy/__init__.py
@@ -0,0 +1,19 @@
+# ==============================================================================
+# Copyright Contributors to the oneDAL Project
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+from .dummy import DummyEstimator
+
+__all__ = ["DummyEstimator"]
diff --git a/onedal/dummy/dummy.cpp b/onedal/dummy/dummy.cpp
new file mode 100644
index 0000000000..66961159e8
--- /dev/null
+++ b/onedal/dummy/dummy.cpp
@@ -0,0 +1,195 @@
+/*******************************************************************************
+* Copyright Contributors to the oneDAL Project
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+*     http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#include "onedal/common.hpp"
+#include "onedal/version.hpp"
+// A fake oneDAL algorithm is included via the `dummy_onedal.hpp` header. In
+// normal circumstances a header for the oneDAL algorithm would be
+// included here from the oneDAL `oneapi/dal/algo/` folder.
+#include "onedal/dummy/dummy_onedal.hpp"
+#include "oneapi/dal/table/common.hpp"
+#include "oneapi/dal/table/homogen.hpp"
+
+namespace py = pybind11;
+
+// oneDAL-python interaction code is located in this namespace
+namespace oneapi::dal::python {
+
+// pybind11 structures and functions of the 'dummy' algorithm
+namespace dummy {
+
+template <typename Task, typename Ops>
+struct method2t {
+    method2t(const Task& task, const Ops& ops) : ops(ops) {}
+
+    // This functor converts the ``method`` parameter string into a valid
+    // oneDAL method type. Methods are specific to each algorithm, therefore
+    // method2t is often defined for each algo.
+    template <typename Float>
+    auto operator()(const py::dict& params) {
+        using namespace dal::dummy;
+        const auto method = params["method"].cast<std::string>();
+
+        ONEDAL_PARAM_DISPATCH_VALUE(method, "dense", ops, Float, method::dense);
+        ONEDAL_PARAM_DISPATCH_VALUE(method, "by_default", ops, Float, method::by_default);
+        ONEDAL_PARAM_DISPATCH_THROW_INVALID_VALUE(method);
+    }
+
+    Ops ops;
+};
+
+struct params2desc {
+    // This functor converts the params dictionary into a oneDAL descriptor.
+    template <typename Float, typename Method, typename Task>
+    auto operator()(const py::dict& params) {
+        auto desc = dal::dummy::descriptor<Float, Method, Task>();
+
+        // Conversion of the params dict to oneDAL params occurs here, except
+        // for the ``method`` and ``fptype`` parameters which are handled by
+        // the method2t and fptype2t functors. The remaining parameters are
+        // assigned to the descriptor individually before returning.
+        const auto constant = params["constant"].cast<double>();
+        desc.set_constant(constant);
+
+        return desc;
+    }
+};
+
+// The following functions define the python interface methods for the
+// oneDAL algorithms. They are templated on the policy (which may be host,
+// dpc, or spmd) and the task, which is defined per algorithm. They are all
+// defined using lambda functions (a common occurrence for pybind11), but
+// that is not a requirement.
+template <typename Policy, typename Task>
+void init_train_ops(py::module& m) {
+    m.def("train", [](const Policy& policy, const py::dict& params, const table& data) {
+        using namespace dal::dummy;
+        using input_t = train_input<Task>;
+        // While there is a train_ops defined for each oneDAL algorithm
+        // which supports ``train``, this is the train_ops defined in
+        // onedal/common/dispatch_utils.hpp.
+        train_ops ops(policy, input_t{ data }, params2desc{});
+        // fptype2t is defined in common/dispatch_utils.hpp and operates in
+        // a similar manner to the method2t functor: it selects the floating
+        // point datatype for the calculation.
+        return fptype2t{ method2t{ Task{}, ops } }(params);
+    });
+};
+
+template <typename Policy, typename Task>
+void init_infer_ops(py::module_& m) {
+    m.def(
+        "infer",
+        [](const Policy& policy, const py::dict& params, const table& constant, const table& data) {
+            using namespace dal::dummy;
+            using input_t = infer_input<Task>;
+
+            infer_ops ops(policy, input_t{ data, constant }, params2desc{});
+            // With the use of functors, the order of operations is as
+            // follows: the Task is generated, the ops is already created
+            // above, method2t is constructed, and then fptype2t is
+            // constructed. They are then evaluated in the opposite order,
+            // sequentially, on the params dict.
+            return fptype2t{ method2t{ Task{}, ops } }(params);
+        });
+}
+
+// This defines the result C++ objects for use in python via pybind11.
+// Result object attributes should be pybind11 native types (like int,
+// float, etc.) or oneDAL tables.
+
+template <typename Task>
+void init_train_result(py::module_& m) {
+    using namespace dal::dummy;
+    using result_t = train_result<Task>;
+
+    py::class_<result_t>(m, "train_result").def(py::init()).DEF_ONEDAL_PY_PROPERTY(data, result_t);
+}
+
+template <typename Task>
+void init_infer_result(py::module_& m) {
+    using namespace dal::dummy;
+    using result_t = infer_result<Task>;
+
+    py::class_<result_t>(m, "infer_result").def(py::init()).DEF_ONEDAL_PY_PROPERTY(data, result_t);
+}
+
+ONEDAL_PY_DECLARE_INSTANTIATOR(init_train_result);
+ONEDAL_PY_DECLARE_INSTANTIATOR(init_infer_result);
+ONEDAL_PY_DECLARE_INSTANTIATOR(init_train_ops);
+ONEDAL_PY_DECLARE_INSTANTIATOR(init_infer_ops);
+
+} // namespace dummy
+
+// This specialization must be visible before the instantiations below use it.
+ONEDAL_PY_TYPE2STR(dal::dummy::task::generate, "generate");
+
+ONEDAL_PY_INIT_MODULE(dummy) {
+    using namespace dummy;
+    using namespace dal::detail;
+    using namespace dal::dummy;
+
+    // The task_list allows templates for multiple types of tasks (like
+    // regression and classification) to be evaluated. The use of ``types``
+    // is not required, and it has special implications for the
+    // ``bind_default_backend`` function, as it creates submodules in python
+    // based on the task name. See the covariance implementation, where no
+    // task_list is used and a submodule of the algorithm is not made.
+    using task_list = types<task::generate>;
+    auto sub = m.def_submodule("dummy");
+
+    // Explicitly define the templates based off of the policy and task
+    // lists. These instantiations lead to a cascade of fully-resolved
+    // templates from oneDAL. It begins by fully resolving functors defined
+    // here and the oneDAL descriptor. It then fully specifies functors in
+    // common/dispatch_utils.hpp, which starts resolving oneDAL objects
+    // for the algorithm like the train_ops/infer_ops functors defined there.
+    // This leads to a fair amount of compile-time work with oneDAL headers.
+    // For example, take init_train_ops in approximate reverse order
+    // (to show how it goes from here to oneDAL):
+    //
+    // 0. Creates pybind11 interface
+    // 1. Specifies lambda defined in init_train_ops
+    // 2. Specifies fptype2t
+    // 3. Specifies method2t
+    // 4. Specifies train_ops defined in common/dispatch_utils.hpp
+    // 5. Specifies train defined in oneapi/dal/train.hpp
+    // 6. Specifies train_dispatch in oneapi/dal/detail/train_ops.hpp
+    // 7. Specifies several functors in oneapi/dal/detail/ops_dispatcher.hpp
+    // 8. Specifies train_ops defined in the algorithm's train_ops.hpp
+    // 9. Specifies oneDAL train_input, train_result and descriptor structs
+    /**** finally hits objects compiled in oneDAL for the computation ****/
+    // (train_ops_dispatcher for example)
+    //
+    // It's not clear how many layers of these indirections are compiled
+    // versus optimized away. The naming in dispatch_utils.hpp is also
+    // unfortunate and confusing.
+
+    // policy_list is defined elsewhere and is dependent on the backend
+    // which is being built. It is placed within a macro check in order to
+    // prevent use with an spmd policy.
+#ifndef ONEDAL_DATA_PARALLEL_SPMD
+    ONEDAL_PY_INSTANTIATE(init_train_ops, sub, policy_list, task_list);
+    ONEDAL_PY_INSTANTIATE(init_infer_ops, sub, policy_list, task_list);
+    ONEDAL_PY_INSTANTIATE(init_train_result, sub, task_list);
+    ONEDAL_PY_INSTANTIATE(init_infer_result, sub, task_list);
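+    // For orientation only (illustrative; the exact attribute path depends
+    // on the backend loading logic in onedal/__init__.py): after these
+    // instantiations the host backend exposes the functions to Python
+    // roughly as ``_onedal_py_host.dummy.generate.train(policy, params,
+    // data)``, which is what ``bind_default_backend`` looks up by name.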
+#else
+    // This is where the pybind11 init functions would be instantiated with
+    // a policy_spmd object. For example, if an init_train_ops existed for
+    // the spmd backend it would be instantiated like:
+    // ONEDAL_PY_INSTANTIATE(init_train_ops, sub, policy_spmd, task_list);
+#endif
+}
+
+} // namespace oneapi::dal::python
diff --git a/onedal/dummy/dummy.py b/onedal/dummy/dummy.py
new file mode 100644
index 0000000000..99621e0634
--- /dev/null
+++ b/onedal/dummy/dummy.py
@@ -0,0 +1,137 @@
+# ==============================================================================
+# Copyright Contributors to the oneDAL Project
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""This file describes necessary characteristics and design patterns of onedal estimators.
+
+This can be used as a foundation for developing other estimators. Most
+comments guiding code development should be removed unless pertinent to the
+implementation."""
+
+from .._device_offload import supports_queue
+from ..common._backend import bind_default_backend
+from ..datatypes import from_table, to_table
+
+
+class DummyEstimator:
+    # This class creates a constant 2d array of a specific size as an example.
+
+    def __init__(self, constant=False):
+        # The __init__ method should only assign class attributes matching
+        # the input parameters (similar to sklearn). It should not assign
+        # any attributes which aren't related to the operation of oneDAL.
+        # This means that it should not conform to sklearn, only to
+        # oneDAL. Don't add unnecessary attributes which only match sklearn;
+        # these should be translated by the sklearnex estimator. In this case
+        # the only parameter for the dummy algorithm is the `constant` param.
+        self.constant = constant
+        self._onedal_model = None
+
+    # See the documentation on bind_default_backend. There exist three
+    # possible oneDAL pybind11 interfaces: 'host', 'dpc' and 'spmd'. These
+    # are for cpu-only, cpu and gpu, and multi-device computation
+    # respectively. Logic in the onedal module will determine which can be
+    # used at import time. It will attempt to use the 'dpc' interface if
+    # possible (which enables gpu computation), but that requires a SYCL
+    # runtime. If not possible it will silently fall back to the 'host'
+    # pybind11 interface. The backend binding logic will seamlessly handle
+    # this for the estimator. The 'spmd' backend is specific to onedal
+    # estimators defined in the 'spmd' folder. The binding makes the
+    # pybind11 function a method of this class with the same name (in this
+    # case ``DummyEstimator.train`` calls the pybind11 function
+    # ``onedal.<backend>.dummy.generate.train``, where the backend can be
+    # one of 'host', 'dpc' or 'spmd').
+    @bind_default_backend("dummy.generate")
+    def train(self, params, data_table): ...
+
+    @bind_default_backend("dummy.generate")
+    def infer(self, params, model, data_table): ...
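+
+    # A minimal usage sketch of this estimator (illustrative only; assumes a
+    # successfully built backend and numpy inputs):
+    #
+    #     >>> import numpy as np
+    #     >>> from onedal.dummy import DummyEstimator
+    #     >>> est = DummyEstimator(constant=3.0)
+    #     >>> est.fit(np.ones((5, 2)), np.ones((5, 1)))
+    #     >>> est.predict(np.ones((4, 2)))  # -> (4, 1) numpy array of 3.0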
+
+    @supports_queue
+    def fit(self, X, y, queue=None):
+        # Convert the data to oneDAL tables in preparation for use by the
+        # oneDAL pybind11 interfaces/objects.
+        X_t, y_t = to_table(X, y)
+
+        # Generating the params dict can be centralized into a class method,
+        # but it must be named ``_get_onedal_params``. Parameter 'fptype' is
+        # specific to the pybind11 interface, and cannot be found in oneDAL
+        # documentation. It tells oneDAL what float type to use for the
+        # computation. The safest and best way to assign this value is after
+        # the input data has been converted to a oneDAL table, as the dtype
+        # is standardized (taken care of by ``to_table``). This dtype is a
+        # ``numpy`` dtype due to its ubiquity and native support in pybind11.
+        params = {
+            "fptype": y_t.dtype,  # normally X_t.dtype is used
+            "method": "dense",
+            "constant": self.constant,
+        }
+
+        # This is the call to the oneDAL pybind11 backend, which was
+        # previously bound using ``bind_default_backend``. It returns a
+        # pybind11 Python interface to the oneDAL C++ result object.
+        result = self.train(params, y_t)
+        # In general the naming convention of ``fit`` maps to oneDAL's
+        # ``train``, and ``predict`` maps to oneDAL's ``infer``. Please
+        # refer to the oneDAL design documentation (headers under
+        # cpp/oneapi/dal in the oneDAL repository) to determine the best
+        # translation, as well as for other counterparts like ``compute``
+        # and ``partial_train``. Generally the sklearn naming scheme for
+        # class methods should be used here, but calls to the pybind11
+        # interfaces should follow oneDAL naming.
+
+        # Oftentimes oneDAL table objects are attributes of the oneDAL C++
+        # result object. These can be converted into various common data
+        # frameworks like ``numpy`` or ``dpctl.tensor`` using ``from_table``.
+        # Basic python types (like bool) need no special handling, as
+        # pybind11 converts them natively. Attributes of the result object
+        # are copied to attributes of the onedal estimator object.
+
+        self.constant_, self.fit_X_, self.fit_y_ = from_table(
+            result.data, X_t, y_t, like=X
+        )
+        # The fit_X_ and fit_y_ attributes are not required and are generally
+        # discouraged. They are set here only to show the process of setting
+        # and returning array values. When setting return attributes, any
+        # post-processing of the values beyond the conversion needed for
+        # sklearn must occur in the sklearnex estimator.
+
+    def _create_model(self):
+        # While doing something rather trivial, this is closer to what may
+        # occur in other estimators which can generate models just in time.
+        # Necessary attributes are collected, converted to oneDAL tables
+        # and set on the oneDAL object. In general there should be a oneDAL
+        # model class defined with serialization and deserialization via a
+        # pybind11 interface.
+
+        # When the model is a oneDAL object (see svm), it must maintain a
+        # pybind11 interface to the `serialization` and `deserialization`
+        # oneDAL routines for proper pickling of the oneDAL object. oneDAL
+        # tables must be converted to the array type of the fitted estimator
+        # for pickling/unpickling (see any incremental estimator pybind11
+        # implementation).
+
+        # This example just treats a oneDAL table as the model.
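+        # For contrast, a hedged sketch of what a model-producing estimator
+        # might do here instead (the names below are illustrative only and
+        # are not an actual oneDAL API; see the svm implementation for the
+        # real pattern):
+        #
+        #     model = self._get_backend_model_class()()
+        #     model.support_vectors = to_table(self.support_vectors_)
+        #     model.coeffs = to_table(self.dual_coef_.T)
+        #     return model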
+        return to_table(self.constant_)
+
+    @supports_queue
+    def predict(self, X, queue=None):
+        X_t = to_table(X)
+        if self._onedal_model is None:
+            self._onedal_model = self._create_model()
+
+        params = {"fptype": X_t.dtype, "method": "dense", "constant": self.constant}
+        result = self.infer(params, self._onedal_model, X_t)
+        return from_table(result.data, like=X)
diff --git a/onedal/dummy/dummy_onedal.hpp b/onedal/dummy/dummy_onedal.hpp
new file mode 100644
index 0000000000..d3bc6c00c5
--- /dev/null
+++ b/onedal/dummy/dummy_onedal.hpp
@@ -0,0 +1,322 @@
+/*******************************************************************************
+* Copyright Contributors to the oneDAL Project
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+*     http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#pragma once
+
+#include <cassert>
+#include <stdexcept>
+
+#include "onedal/common.hpp"
+#include "onedal/version.hpp"
+#include "oneapi/dal/table/common.hpp"
+#include "oneapi/dal/table/detail/homogen_utils.hpp"
+#include "oneapi/dal/train.hpp"
+#include "oneapi/dal/infer.hpp"
+#include "oneapi/dal/detail/policy.hpp"
+
+namespace py = pybind11;
+
+namespace oneapi::dal {
+
+////////////////////////// Dummy oneDAL Algorithm /////////////////////////
+// These aspects fake the necessary API characteristics of an algorithm
+// from the oneDAL repository. This example foregoes the indirections used
+// with impl_ attributes characteristic of the oneDAL codebase and only
+// shows the necessary APIs. It is also as minimal as possible, dropping
+// some required setters/getters for brevity. It also violates some rules
+// with respect to protected/private attributes and compile-time type
+// checking.
+//
+// Files which are normally separated in oneDAL for clarity are merged here
+// to provide an overview of what is necessary for interaction in sklearnex
+// from a high level.
+//
+// To support oneDAL offloading, the task, method and descriptor structs
+// need to be defined as in the algorithm's common.hpp.
+//
+// For the various modes (e.g. training, inference), the requisite functors
+// and result data structs need to be defined. Usually this is in
+// *_types.hpp. For example, a 'compute' algorithm would have a
+// compute_types.hpp.
+//
+// Usually these aspects are all made available via the algorithm's header
+// file located in oneapi/dal/algo.
+//
+// This should act as a guide for where to look and what to reference in
+// oneDAL for making a pybind11 interface.
+
+/////////////////////////////// common.hpp ////////////////////////////////
+namespace dummy {
+
+namespace task {
+// Tasks can be arbitrarily named; ``by_default`` must be defined.
+struct generate {};
+using by_default = generate;
+} // namespace task
+
+namespace method {
+// Methods can be arbitrarily named, though the name will be used in the
+// python onedal estimator as a parameter.
+struct dense {};
+using by_default = dense;
+} // namespace method
+
+namespace detail {
+// This tag is highly important for central use of train, compute, infer,
+// etc. in oneDAL, but is not used in sklearnex (and must be included here).
+struct descriptor_tag {};
+
+} // namespace detail
+
+template <typename Float = float,
+          typename Method = method::by_default,
+          typename Task = task::by_default>
+class descriptor : public base {
+public:
+    using tag_t = detail::descriptor_tag;
+    using float_t = Float;
+    using method_t = Method;
+    using task_t = Task;
+
+    descriptor() : constant(0.0) {}
+
+    double get_constant() const {
+        return this->constant;
+    }
+
+    auto& set_constant(double value) {
+        this->constant = value;
+        return *this;
+    }
+
+    // Normally this attribute is hidden in another struct.
+private:
+    double constant;
+};
+} // namespace dummy
+/////////////////////////////// common.hpp ////////////////////////////////
+
+///////////////////////////// train_types.hpp /////////////////////////////
+
+namespace dummy {
+
+template <typename Task = task::by_default>
+class train_result {
+public:
+    using task_t = Task;
+
+    train_result() {}
+
+    const table& get_data() {
+        return this->data;
+    }
+
+    auto& set_data(const table& value) {
+        data = value;
+        return *this;
+    }
+
+    // Attribute usually hidden in a train_result_impl class.
+private:
+    table data;
+};
+
+template <typename Task = task::by_default>
+class train_input : public base {
+public:
+    using task_t = Task;
+
+    train_input(const table& data) : data(data) {}
+
+    // Attribute usually hidden in a train_input_impl class with getters
+    // and setters.
+    table data;
+};
+} // namespace dummy
+
+///////////////////////////// train_types.hpp /////////////////////////////
+
+///////////////////////////// infer_types.hpp /////////////////////////////
+namespace dummy {
+template <typename Task = task::by_default>
+class infer_result {
+public:
+    using task_t = Task;
+
+    infer_result() {}
+
+    const table& get_data() {
+        return this->data;
+    }
+
+    auto& set_data(const table& value) {
+        data = value;
+        return *this;
+    }
+
+    // Attribute usually hidden in an infer_result_impl class.
+private:
+    table data;
+};
+
+template <typename Task = task::by_default>
+class infer_input : public base {
+public:
+    using task_t = Task;
+
+    infer_input(const table& data, const table& constant) : data(data), constant(constant) {}
+    // Setters and getters for ``data`` and ``constant`` removed for brevity.
+
+    // Attributes usually hidden in an infer_input_impl class with getters
+    // and setters.
+    table data;
+    table constant;
+};
+} // namespace dummy
+///////////////////////////// infer_types.hpp /////////////////////////////
+
+/////// THESE ARE PRIVATE STEPS REQUIRED FOR IT TO WORK WITH ONEDAL ///////
+
+using dal::detail::host_policy;
+#ifdef ONEDAL_DATA_PARALLEL
+using dal::detail::data_parallel_policy;
+#endif
+
+template <typename float_t>
+dal::homogen_table create_full_table(const host_policy& ctx,
+                                     std::int64_t row_c,
+                                     std::int64_t col_c,
+                                     float_t val) {
+    dal::array<float_t> array = dal::array<float_t>::full(col_c * row_c, val);
+    return dal::homogen_table::wrap(array, row_c, col_c);
+}
+
+#ifdef ONEDAL_DATA_PARALLEL
+template <typename float_t>
+dal::homogen_table create_full_table(const data_parallel_policy& ctx,
+                                     std::int64_t row_c,
+                                     std::int64_t col_c,
+                                     float_t val) {
+    auto queue = ctx.get_queue();
+    dal::array<float_t> array = dal::array<float_t>::full(queue, col_c * row_c, val);
+    return dal::homogen_table::wrap(array, row_c, col_c);
+}
+#endif
+
+////////////////////////////// train_ops.hpp //////////////////////////////
+namespace dummy {
+namespace detail {
+
+template <typename Descriptor>
+struct train_ops {
+    using float_t = typename Descriptor::float_t;
+    using task_t = typename Descriptor::task_t;
+    using method_t = method::by_default;
+    using input_t = train_input<task_t>;
+    using result_t = train_result<task_t>;
+
+    template <typename Policy>
+    auto operator()(const Policy& ctx, const Descriptor& desc, const input_t& input) const {
+        // Usually a train_ops_dispatcher is contained in the oneDAL
+        // train_ops.cpp. Due to the simplicity of this algorithm, it is
+        // implemented here.
+        auto col_c = input.data.get_column_count();
+        result_t result;
+        result.set_data(
+            create_full_table(ctx, 1, col_c, static_cast<float_t>(desc.get_constant())));
+        return result;
+    }
+};
+} // namespace detail
+} // namespace dummy
+////////////////////////////// train_ops.hpp //////////////////////////////
+
+//////////////////////////////// train.hpp ////////////////////////////////
+
+namespace detail {
+namespace v1 {
+
+template <typename Descriptor>
+struct train_ops<Descriptor, dal::dummy::detail::descriptor_tag>
+        : dal::dummy::detail::train_ops<Descriptor> {};
+} // namespace v1
+} // namespace detail
+
+//////////////////////////////// train.hpp ////////////////////////////////
+
+////////////////////////////// infer_ops.hpp //////////////////////////////
+namespace dummy {
+namespace detail {
+
+template <typename Descriptor>
+struct infer_ops {
+    using float_t = typename Descriptor::float_t;
+    using task_t = typename Descriptor::task_t;
+    using method_t = method::by_default;
+    using input_t = infer_input<task_t>;
+    using result_t = infer_result<task_t>;
+
+    template <typename Policy>
+    auto operator()(const Policy& ctx, const Descriptor& desc, const input_t& input) const {
+        // Usually an infer_ops_dispatcher is contained in the oneDAL
+        // infer_ops.cpp. Due to the simplicity of this algorithm, it is
+        // implemented here.
+        auto row_c = input.data.get_row_count();
+        auto col_c = input.constant.get_column_count();
+        assert(input.constant.get_kind() == dal::homogen_table::kind());
+        result_t result;
+        const byte_t* ptr =
+            dal::detail::get_original_data(static_cast<const dal::homogen_table&>(input.constant))
+                .get_data();
+
+        float_t inp;
+        // Only switch on those types which can be converted from python to
+        // oneDAL tables.
+        switch (input.constant.get_metadata().get_data_type(0)) {
+            case dal::data_type::float32: {
+                inp = static_cast<float_t>(*reinterpret_cast<const float*>(ptr));
+                break;
+            }
+            case dal::data_type::float64: {
+                inp = static_cast<float_t>(*reinterpret_cast<const double*>(ptr));
+                break;
+            }
+            case dal::data_type::int32: {
+                inp = static_cast<float_t>(*reinterpret_cast<const std::int32_t*>(ptr));
+                break;
+            }
+            case dal::data_type::int64: {
+                inp = static_cast<float_t>(*reinterpret_cast<const std::int64_t*>(ptr));
+                break;
+            }
+            default: throw std::runtime_error("incompatible input type");
+        }
+        result.set_data(create_full_table(ctx, row_c, col_c, inp));
+        return result;
+    }
+};
+} // namespace detail
+} // namespace dummy
+////////////////////////////// infer_ops.hpp //////////////////////////////
+
+//////////////////////////////// infer.hpp ////////////////////////////////
+
+namespace detail {
+namespace v1 {
+
+template <typename Descriptor>
+struct infer_ops<Descriptor, dal::dummy::detail::descriptor_tag>
+        : dal::dummy::detail::infer_ops<Descriptor> {};
+} // namespace v1
+} // namespace detail
+
+//////////////////////////////// infer.hpp ////////////////////////////////
+
+////////////////////////// Dummy oneDAL Algorithm /////////////////////////
+} // namespace oneapi::dal
diff --git a/setup.py b/setup.py
index aa13495d1c..a577cb868c 100644
--- a/setup.py
+++ b/setup.py
@@ -546,6 +546,7 @@ class build(onedal_build, orig_build.build):
         "onedal.covariance",
         "onedal.datatypes",
         "onedal.decomposition",
+        "onedal.dummy",
         "onedal.ensemble",
         "onedal.neighbors",
         "onedal.primitives",
@@ -556,6 +557,7 @@ class build(onedal_build, orig_build.build):
         "sklearnex.cluster",
         "sklearnex.covariance",
         "sklearnex.decomposition",
+        "sklearnex.dummy",
         "sklearnex.ensemble",
         "sklearnex.glob",
         "sklearnex.linear_model",
diff --git a/sklearnex/__init__.py b/sklearnex/__init__.py
index af09d9bdb4..0e01052c9a 100755
--- a/sklearnex/__init__.py
+++ b/sklearnex/__init__.py
@@ -34,6 +34,7 @@
     "config_context",
    "covariance",
     "decomposition",
+    "dummy",
     "ensemble",
     "get_config",
     "get_hyperparameters",
diff --git a/sklearnex/dispatcher.py b/sklearnex/dispatcher.py
index ae77b6c73e..b63c8c38a5 100644
--- a/sklearnex/dispatcher.py
+++ b/sklearnex/dispatcher.py
@@ -101,7 +101,13 @@ def get_patch_map_core(preview=False):
     import sklearn.cluster as cluster_module
     import sklearn.covariance as covariance_module
     import sklearn.decomposition as decomposition_module
+    import sklearn.dummy as dummy_module
     import sklearn.ensemble as ensemble_module
+
+    if sklearn_check_version("1.4"):
+        import sklearn.ensemble._gb as _gb_module
+    else:
+        import sklearn.ensemble._gb_losses as _gb_module
     import sklearn.linear_model as linear_model_module
     import sklearn.manifold as manifold_module
     import sklearn.metrics as metrics_module
@@ -124,6 +130,7 @@ def get_patch_map_core(preview=False):
         IncrementalEmpiricalCovariance as IncrementalEmpiricalCovariance_sklearnex,
     )
     from .decomposition import PCA as PCA_sklearnex
+    from .dummy import DummyRegressor as DummyRegressor_sklearnex
     from .ensemble import ExtraTreesClassifier as ExtraTreesClassifier_sklearnex
     from .ensemble import ExtraTreesRegressor as ExtraTreesRegressor_sklearnex
     from .ensemble import RandomForestClassifier as RandomForestClassifier_sklearnex
@@ -408,6 +415,31 @@ def get_patch_map_core(preview=False):
         ]
     ]
 
+    # DummyRegressor
+    mapping["dummyregressor"] = [
+        [
+            (
+                dummy_module,
+                "DummyRegressor",
+                DummyRegressor_sklearnex,
+            ),
+            None,
+        ]
+    ]
+
+    # Patching of DummyRegressor in the gradient boosting module is also
+    # required, as it is used in the GradientBoosting algorithms.
+    mapping["gb_dummyregressor"] = [
+        [
+            (
+                _gb_module,
+                "DummyRegressor",
+                DummyRegressor_sklearnex,
+            ),
+            None,
+        ]
+    ]
+
     # Configs
     mapping["set_config"] = [
         [(base_module, "set_config", set_config_sklearnex), None]
diff --git a/sklearnex/dummy/__init__.py b/sklearnex/dummy/__init__.py
new file mode 100644
index 0000000000..2cfeb26e32
--- /dev/null
+++ b/sklearnex/dummy/__init__.py
@@ -0,0 +1,19 @@
+# ===============================================================================
+# Copyright Contributors to the oneDAL Project
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ===============================================================================
+
+from ._dummy import DummyRegressor
+
+__all__ = ["DummyRegressor"]
diff --git a/sklearnex/dummy/_dummy.py b/sklearnex/dummy/_dummy.py
new file mode 100644
index 0000000000..9b374b13b9
--- /dev/null
+++ b/sklearnex/dummy/_dummy.py
@@ -0,0 +1,615 @@
+# ==============================================================================
+# Copyright Contributors to the oneDAL Project
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Sklearnex module estimator design guide and example.
+
+This can be used as a foundation for developing other estimators. Most
+comments guiding code development should be removed if reused, unless
+pertinent to the derivative implementation."""
+
+import numpy as np
+import scipy.sparse as sp
+from sklearn.dummy import DummyRegressor as _sklearn_DummyRegressor
+from sklearn.utils.validation import check_is_fitted
+
+from daal4py.sklearn._n_jobs_support import control_n_jobs
+from daal4py.sklearn._utils import daal_check_version, sklearn_check_version
+from onedal._device_offload import support_input_format
+from onedal.dummy import DummyEstimator as onedal_DummyEstimator
+
+from .._device_offload import dispatch
+from .._utils import PatchingConditionsChain
+from ..base import oneDALEstimator
+from ..utils._array_api import enable_array_api, get_namespace
+from ..utils.validation import validate_data
+
+################
+# IMPORT NOTES #
+################
+#
+# 1) All sklearnex estimators must inherit oneDALEstimator and the sklearn
+# estimator that they replicate (i.e. oneDALEstimator comes first in the
+# mro). If there is no equivalent sklearn estimator, then sklearn's
+# BaseEstimator must be inherited.
+#
+# 2) ``check_is_fitted`` is required for any method in an estimator which
+# requires first calling ``fit`` or ``partial_fit``. This is a sklearn
+# requirement.
+#
+# 3) Every estimator should be decorated by ``control_n_jobs`` to properly
+# create parallelization control for the oneDAL library via the ``n_jobs``
+# parameter. This parameter is added to ``__init__`` automatically.
+#
+# 4) For compatibility reasons, ``daal_check_version`` and
+# ``sklearn_check_version`` add or remove capabilities based on the installed
+# oneDAL library and scikit-learn package. This is often necessary for the
+# state of oneDAL development and scikit-learn characteristics. These should
+# be used at import time instead of run time whenever possible/practical.
+#
+# 5) If a sklearn estimator is imported, it must have the ``_sklearn_``
+# prefix added upon import in order to prevent its discovery, highlight
+# its nature as private, and prevent a namespace collision. Any onedal
+# imported estimator should similarly have the ``onedal_`` prefix added
+# (as it should have the same name as the sklearnex estimator).
+#
+# 6) ``dispatch`` is a key central function for evaluating data with either
+# oneDAL or sklearn. All oneDAL algorithms which are to be directly used
+# should be accessed via this function. It should not be used unless a
+# call to a onedal estimator occurs.
+#
+# 7) ``PatchingConditionsChain`` is used in conjunction with ``dispatch``
+# and the methods ``_onedal_cpu_supported`` and ``_onedal_gpu_supported`` to
+# evaluate whether the required evaluation on the data is possible with
+# oneDAL or must fall back to sklearn.
+#
+# 8) ``get_namespace`` is key for array API support; it yields the
+# namespace associated with the given array for use in the data conversion
+# necessary to and from oneDAL. An internal version is preferred due to
+# limitations in sklearn versions and specific DPEX data framework support
+# (see dpctl tensors and dpnp).
+#
+# 9) ``validate_data`` checks data quality and estimator status before
+# evaluating the function. This replicates a sklearn functionality with key
+# performance changes implemented in oneDAL and therefore should only be
+# imported from sklearnex and not sklearn.
+#
+# 10) All estimators require validation of the parameters given at
+# initialization. This aspect was introduced in sklearn 1.2; any additional
+# parameters must extend the dictionary for checking. This validation
+# normally occurs in the ``fit`` method.
+#
+
+##########################
+# METHOD HIERARCHY NOTES #
+##########################
+#
+# Sklearnex estimator methods can be thought of in 3 major tiers.
+#
+# Tier 1: Methods which offload to oneDAL using ``dispatch``. Typical
+# examples are ``fit`` and ``predict``. They use a directly equivalent
+# oneDAL function for evaluation. These methods are of highest priority and
+# have performance benchmark requirements.
+#
+# Tier 2: Methods that use a Tier 1 method with additional Python
+# calculations (usually a sklearn method or applied math function). Examples
+# are ``kneighbors_graph`` and ``predict_log_proba``. Oftentimes the
+# additional calculations are trivial, meaning benchmarking is not required.
+#
+# Tier 3: Methods which directly use sklearn functionality. Typically these
+# can be directly inherited, but inheritance can be problematic with respect
+# to other framework support. These can be wrapped with the sklearnex
+# function ``wrap_output_data`` to guarantee array API, dpctl tensor, and
+# dpnp support, but should be addressed with care/guidance on a
+# case-by-case basis.
+#
+# When a sklearnex method replaces an inherited sklearn method, it must
+# match the method signature exactly. For sklearnex-only estimators,
+# attempt to match the conventions of the sklearn estimators which are most
+# closely related.
+
+########################
+# CONTROL_N_JOBS NOTES #
+########################
+#
+# All tier 1 methods should be in the decorated_methods list for oneDAL
+# parallelism control. In general, changes to oneDAL parallelism should only
+# be done once per public method call. This may mean some tier 2 methods
+# must be added to the list along with some restructuring of the related
+# tier 1 methods. An illustrative example could be an estimator which
+# implements ``fit_transform``, where combining the ``fit`` and
+# ``transform`` tier 1 methods may set n_jobs twice.
+
+
+# enable_array_api enables the sklearnex code to work with and directly pass
+# array_api and dpep framework data (dpnp, dpctl tensors, and pytorch for
+# example) to the oneDAL backend.
+@enable_array_api
+@control_n_jobs(decorated_methods=["fit", "predict"])
+class DummyRegressor(oneDALEstimator, _sklearn_DummyRegressor):
+    # All sklearnex estimators must inherit a sklearn estimator; sklearnex-
+    # only estimators are indicated by the inheritance of sklearn's
+    # BaseEstimator. For estimators without a sklearn equivalent, the
+    # inherited oneDALEstimator must occur directly before BaseEstimator
+    # in the mro.
+
+    ##################################
+    # GENERAL ESTIMATOR DESIGN NOTES #
+    ##################################
+    #
+    # As a rule, conform to the sklearn design rules as much as possible
+    # (https://scikit-learn.org/stable/developers/develop.html).
+    # This includes inheriting the proper sklearn Mixin classes depending
+    # on the sklearnex estimator functionality.
+    #
+    # All estimators should be defined in a Python file located in a folder
+    # limited to the folder names in this directory:
+    # https://github.com/scikit-learn/scikit-learn/tree/main/sklearn
+    # All estimators should be properly added to the patching map located
+    # in sklearnex/dispatcher.py following the convention made there. This
+    # is important for having the estimator properly tested and available
+    # in sklearn.
+    #
+    # Sklearnex estimators follow a Matryoshka doll pattern with respect to
+    # the underlying oneDAL library:
+    #
+    # - The sklearnex estimator is a public-facing API which mimics sklearn.
+    #
+    # - The onedal estimator object determines and sets characteristics
+    # with respect to oneDAL offloading (including but not limited to
+    # parameters, the SYCL queue, and data and result conversion).
+    #
+    # - The pybind11 interface objects map to oneDAL C++ objects (inputs,
+    # results, and methods).
+    #
+    # These 3 objects interact in the following way: sklearnex estimators
+    # will create another estimator, defined in the ``onedal`` module, for
+    # having a Python interface with oneDAL. Finally, this Python object
+    # will use pybind11 to call oneDAL directly via pybind11-generated
+    # objects and functions. This is known as the ``backend``. These are
+    # separate entities and do not inherit from one another. The clear
+    # separation has utility so long as the following rules are followed:
+    #
+    # 1) All conformance to sklearn should happen in sklearnex estimators,
+    # with all variations between the supported sklearn versions handled
+    # there. This includes transforming result data into a format which
+    # matches sklearn. This is done to minimize and focus maintenance with
+    # respect to sklearn to the sklearnex module.
+    #
+    # 2) The onedal estimator handles the necessary data conversion and
+    # preparation for invoking calls to onedal. These objects should not be
+    # influenced by sklearn design or have any sklearn version dependent
+    # characteristics. Users should be able to use these objects directly
+    # to fit data without sklearn, giving the ability to use raw data
+    # directly and avoiding sklearn pre-processing checks as necessary.
+    #
+    # 3) Pybind11 interfaces should not be made public to the user unless
+    # absolutely necessary, as operation there assumes checks in the other
+    # objects have been sufficiently carried out. In most circumstances, the
+    # pybind11 interface should be invoked by the Python onedal estimator
+    # object.
+    #
+    # 4) If the estimator replicates/inherits from a sklearn estimator,
+    # then the only implemented public methods should be those which
+    # override methods from the sklearn estimator. A sklearn method should
+    # only be overridden if an equivalent oneDAL-accelerated capability
+    # exists, following the tier system described above. If the estimator
+    # is sklearnex-only, then it should try to follow the conventions of
+    # the sklearn estimators which are most closely related (e.g.
+    # IncrementalPCA for incremental estimators). NOTE: as per the sklearn
+    # design rules, all estimator attributes with trailing underscores are
+    # fitted values and are some type of data (and therefore not themselves
+    # oneDAL-accelerated).
+    #
+    # 5) Fitted attributes of the related scikit-learn estimator which are
+    # not defined or calculated by the oneDAL estimator still must be set in
+    # the scikit-learn-intelex estimator to plausible values. These can be
+    # either derived from available oneDAL estimator data or set to
+    # hardcoded values.
+    #
+    # Information about the onedal estimators/objects can be found in an
+    # equivalent class file in the onedal module.
+
+    #######################
+    # DOCUMENTATION NOTES #
+    #######################
+    #
+    # All public methods (i.e. those without leading underscores) should
+    # have documentation which conforms to the numpydoc standard. Generally,
+    # if a defined method replaces an inherited scikit-learn estimator
+    # method, the ``__doc__`` attribute should be re-applied to the new
+    # implementation. Any new additional characteristics compared to the
+    # equivalent sklearn estimator should be appended to the sklearn
+    # docstring.
+    #
+    # When the estimator is added to the patching map in
+    # sklearnex/dispatcher.py, it must be equivalently added to the support
+    # table located in doc/sources/algorithms.rst if replicating a sklearn
+    # estimator. If it is unique to sklearnex, it must be added to
+    # doc/sources/non-scikit-algorithms.rst instead.
+
+    # This is required as part of sklearn conformance, which does checking
+    # of the parameters set in __init__ when self._validate_params is called
+    # (this should only happen in a fit or fit-derived call).
+    if sklearn_check_version("1.2"):
+        _parameter_constraints: dict = {**_sklearn_DummyRegressor._parameter_constraints}
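+        # If the estimator accepted a sklearnex-only parameter, the
+        # constraints dict would be extended here; a hedged sketch (the
+        # parameter name is hypothetical, with ``Interval`` from
+        # sklearn.utils._param_validation and ``numbers`` from the stdlib):
+        #
+        #     _parameter_constraints["batch_size"] = [
+        #         Interval(numbers.Integral, 1, None, closed="left"),
+        #         None,
+        #     ]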
+
+    def __init__(self, *, strategy="mean", constant=None, quantile=None):
+        # Object instantiation is strictly limited by sklearn. It is only
+        # allowed to take the keyword arguments and store them as
+        # attributes with the same name. When replicating a sklearn
+        # estimator, it may be possible to use the inherited version of
+        # ``__init__`` from sklearn. This example defines the parameters
+        # explicitly to highlight the way parameters are set. This is
+        # enforced by sklearn's test_common.py testing.
+        #
+        # The signature of __init__ must match the sklearn estimator
+        # that it replicates (and is verified in test_patching.py).
+        self.strategy = strategy
+        self.constant = constant
+        self.quantile = quantile
+
+    # To generalize for spmd and other use cases, the constructor of the
+    # onedal estimator should be set as an attribute of the class.
+    _onedal_DummyEstimator = staticmethod(onedal_DummyEstimator)
+
+    ############################
+    # TIER 1 METHOD FLOW NOTES #
+    ############################
+    #
+    # Some knowledge of the process flow from the sklearnex perspective is
+    # necessary to understand how to implement an estimator. For Tier 1
+    # methods, the general process is as follows:
+    #
+    # 1) If a method requires a fitted estimator, the method must call
+    # ``check_is_fitted`` before calling ``dispatch``. This verifies
+    # that aspects of the fit are available for analysis (whether oneDAL
+    # may be used or not); usually this means specific instance attributes
+    # with trailing underscores.
+    #
+    # 2) ``dispatch`` is called. This takes the estimator object, method
+    # name, and the two possible evaluation branches and proceeds to call
+    # ``_onedal_gpu_supported`` if a SYCL queue is found or set via the
+    # target offload config. Otherwise ``_onedal_cpu_supported`` is
+    # called.
+    #
+    # 3) ``_onedal_gpu_supported`` or ``_onedal_cpu_supported`` creates a
+    # PatchingConditionsChain object, takes the input data and estimator
+    # parameters, and evaluates whether the estimator and data can be run
+    # using oneDAL. This information is logged to the `sklearnex` logger
+    # via central code (i.e. not by the estimator) in sklearnex.
+    #
+    # 4) Either sklearn is called, or an object from onedal is created and
+    # called using the input data. This process is handled in a function
+    # which has the prefix "_onedal_" followed by the method name. When
+    # fitting data, the returned onedal estimator object is stored as the
+    # ``_onedal_estimator`` attribute.
+    #
+    # 5) Result data is returned from the estimator if necessary. Attributes
+    # from the onedal estimator are copied over to the sklearnex estimator.
+
+    def fit(self, X, y, sample_weight=None):
+        # Parameter validation must be done before calls to dispatch. This
+        # guarantees that the sklearn and onedal use of parameters is
+        # properly typed and valued.
+        if sklearn_check_version("1.2"):
+            self._validate_params()
+
+        # Only arguments from the method signature are passed to
+        # ``_onedal_*_supported``, not kwargs. The parameters of the
+        # estimator are available by default as they are instance
+        # attributes. The choice between sklearn and onedal is based off of
+        # the args, and not the keyword arguments.
+        dispatch(
+            self,
+            "fit",
+            {
+                "onedal": self.__class__._onedal_fit,
+                "sklearn": _sklearn_DummyRegressor.fit,
+            },
+            X,
+            y,
+            sample_weight,
+        )
+        # For sklearnex-only estimators, _onedal_*_supported should either
+        # pass or throw an exception. This means the sklearn branch is never
+        # used. In general, the two branches must be the class methods. The
+        # parameters which are passed as arguments are given to
+        # _onedal_*_supported. In this example, the ``sample_weight`` keyword
+        # argument in the ``fit`` signature is set as an argument to
+        # ``dispatch`` so that it can be properly sent to _onedal_*_supported
+        # for checking oneDAL support.
+
+        # Methods which do not return a result should return self (sklearn
+        # standard).
+        return self
+
+    def predict(self, X, return_std=False):
+        # Note: return_std is a special aspect of the sklearn version of
+        # this estimator; normally the signature is just predict(self, X).
+
+        check_is_fitted(self)  # first check if fitting has occurred
+        # No need to do another parameter check. While the parameters are
+        # modifiable in sklearn and in sklearnex, they should never be
+        # changed by hand after validation.
+        return dispatch(
+            self,
+            "predict",
+            {
+                "onedal": self.__class__._onedal_predict,
+                "sklearn": _sklearn_DummyRegressor.predict,
+            },
+            X,
+            return_std=return_std,  # not important for patching, set as kwarg
+        )
+        # The return value is produced by self._onedal_predict.
+
+    def _onedal_fit(self, X, y, sample_weight=None, queue=None):
+        # The queue must be added as the last kwarg to all onedal-facing
+        # functions. The SYCL queue is acquired in ``dispatch`` and is set
+        # there before calling the ``_onedal_``-prefixed methods.
+
+        # The first step is to always acquire the namespace of the input
+        # data. This is important for getting the proper data types and
+        # possibly other conversions.
+        xp, _ = get_namespace(X, y)
+
+        # The second step must always be to validate the data.
+        # This algorithm can accept 2d y inputs (by setting multi_output).
+        # Note the use of "sklearn_check_version". This is required in order
+        # to conform to changes which occur in sklearn over the supported
+        # versions. The conformance to sklearn should occur in this object;
+        # therefore this function should not be used in the onedal module.
+        # This conformance example is specific to the Dummy estimators.
+        X, y = validate_data(
+            self,
+            X,
+            y,
+            dtype=[xp.float64, xp.float32],
+            multi_output=True,
+            y_numeric=True,
+            ensure_2d=sklearn_check_version("1.2"),
+        )
+        # validate_data does several things:
+        # 1) If not in the proper namespace (depending on array_api configs)
+        # it converts the data to the proper data format (default: numpy
+        # array).
+        # 2) It will check additional aspects for labeled data.
+        # 3) It will convert the arrays to the proper data type, which for
+        # oneDAL is usually float64 or float32, but can also be int32 in
+        # rare circumstances.
+        # The kwargs are often forwarded to sklearn's ``check_array``. It is
+        # best to use the values sklearn sets for the equivalent step. This
+        # is not guaranteed and requires care by the developer. For example,
+        # ``ensure_2d`` is version-dependent here due to the nature of the
+        # class, but would otherwise be left unset.
+
+        # Conformance to sklearn's DummyRegressor
+        if y.ndim == 1:
+            y = xp.reshape(y, (-1, 1))
+        self.n_outputs_ = y.shape[1]
+
+        # In the ``fit`` method, a Python onedal estimator object is
+        # generated.
+        self._onedal_estimator = self._onedal_DummyEstimator(constant=self.constant)
+        # The queue must be passed to the onedal Python estimator object,
+        # though this may change in the future as a requirement.
+        self._onedal_estimator.fit(X, y, queue=queue)
+
+        # Set attributes from _onedal_estimator on the sklearnex estimator.
+        # It is allowed to have a separate private function to do this step.
+        # Below is only an example, but it should cover all the attributes
+        # available from the same sklearn estimator (if not sklearnex-only)
+        # after fitting.
+        self.constant_ = self._onedal_estimator.constant_
+        # See sklearn conventions about trailing underscores for fitted
+        # values.
+
+        # sklearn conformance
+        if self.n_outputs_ != 1 and self.constant_.shape[0] != y.shape[1]:
+            raise ValueError(
+                "Constant target value should have shape (%d, 1)." % y.shape[1]
+            )
+
+    def _onedal_predict(self, X, return_std=None, queue=None):
+        # The first step is to always acquire the namespace of the input
+        # data. This is important for getting the proper data types and
+        # possibly other conversions.
+        xp, _ = get_namespace(X)
+
+        # The second step must always be to validate the data.
+        # Not checking that X is 2d is specific to matching scikit-learn's
+        # DummyRegressor and is not normally done.
+        X = validate_data(
+            self,
+            X,
+            dtype=[xp.float64, xp.float32],
+            reset=False,
+            ensure_2d=sklearn_check_version("1.2"),
+        )
+        # The queue must be sent back to the onedal Python estimator object.
+        y = self._onedal_estimator.predict(X, queue=queue)
+
+        if self.n_outputs_ == 1:
+            y = xp.reshape(y, (-1,))
+
+        y_std = xp.zeros_like(y)
+
+        return (y, y_std) if return_std else y
+
+    def _onedal_cpu_supported(self, method_name, *data):
+        # All estimators must have the following two functions with exactly
+        # these signatures. method_name is a string which must match one
+        # of the tier 1 methods of the estimator. The logic located here
+        # inspects attributes of the data and the estimator to see whether
+        # oneDAL can be used or a fallback to sklearn is required.
+
+        # Begin by generating the PatchingConditionsChain, which requires
+        # modifying the secondary module in the name to match the
+        # estimator's folder, as in the example below.
+        patching_status = PatchingConditionsChain(
+            f"sklearnex.test.{self.__class__.__name__}.{method_name}"
+        )
+        # The conditions specifically compare aspects of the oneDAL
+        # implementation to aspects of the sklearn estimator. For example,
+        # oneDAL may not support sparse inputs where sklearn might; that
+        # would need to be checked with scipy.sparse.issparse(X). In general
+        # the conditions will correspond to information in the metadata
+        # and/or the estimator parameters.
+        #
+        # In no circumstance should ``validate_data`` be called here or
+        # in _onedal_gpu_supported to get the data into the proper form.
+        if method_name == "fit":
+            (X, y, sample_weight) = data
+            xp, _ = get_namespace(X, y)
+
+            # The PatchingConditionsChain is validated using
+            # ``and_conditions``; use of ``or_conditions`` is highly
+            # discouraged. The following checks are specific to this example
+            # and must be tailored to the specific estimator implementation.
+            patching_status.and_conditions(
+                [
+                    (
+                        not sp.issparse(X),
+                        "sparse data is not supported",
+                    ),
+                    (
+                        self.strategy == "constant",
+                        "only the constant strategy is supported",
+                    ),
+                    (
+                        not hasattr(X, "dtype") or X.dtype in (xp.float64, xp.float32),
+                        "oneDAL operates on float64 and float32 inputs",
+                    ),
+                    (
+                        isinstance(self.constant, (int, float)),
+                        "only basic Python types are supported",
+                    ),
+                    (sample_weight is None, "sample_weight is not supported"),
+                ]
+            )
+
+        elif method_name == "predict":
+            # There is a very important subtlety about the ``dispatch``
+            # function and how it interacts with ``_onedal_*_supported``:
+            # only args are used in these methods to evaluate oneDAL
+            # support. This means that kwargs to the public API may become
+            # args in the call to dispatch. In this case, return_std (from
+            # predict) does not impact oneDAL, and is kept as a kwarg in the
+            # ``dispatch`` call in ``predict``. In ``fit`` the kwarg
+            # ``sample_weight`` is important for evaluating oneDAL support
+            # and is passed as an arg.
+            (X,) = data
+
+            patching_status.and_conditions(
+                [
+                    (hasattr(self, "_onedal_estimator"), "oneDAL model was not trained."),
+                    (
+                        not sp.issparse(X),
+                        "sparse data is not supported",
+                    ),
+                ]
+            )
+
+        # The patching_status object must be returned.
+        return patching_status
+
+    def _onedal_gpu_supported(self, method_name, *data):
+        # This method will only be called if it is expected to try and use
+        # a SYCL-enabled GPU. See _onedal_cpu_supported for initial
+        # implementation notes. This should follow the same procedure,
+        # dictated by the characteristics of the GPU oneDAL algorithm.
+        patching_status = PatchingConditionsChain(
+            f"sklearnex.test.{self.__class__.__name__}.{method_name}"
+        )
+        if method_name == "fit":
+            (X, y, sample_weight) = data
+            xp, _ = get_namespace(X, y)
+
+            patching_status.and_conditions(
+                [
+                    (
+                        not sp.issparse(X),
+                        "sparse data is not supported",
+                    ),
+                    (
+                        self.strategy == "constant",
+                        "only the constant strategy is supported",
+                    ),
+                    (
+                        not hasattr(X, "dtype") or X.dtype in (xp.float64, xp.float32),
+                        "oneDAL operates on float64 and float32 inputs",
+                    ),
+                    (
+                        isinstance(self.constant, (int, float)),
+                        "only basic Python types are supported",
+                    ),
+                    (sample_weight is None, "sample_weight is not supported"),
+                ]
+            )
+
+        elif method_name == "predict":
+            (X,) = data
+
+            patching_status.and_conditions(
+                [
+                    (hasattr(self, "_onedal_estimator"), "oneDAL model was not trained."),
+                    (
+                        not sp.issparse(X),
+                        "sparse data is not supported",
+                    ),
+                ]
+            )
+
+        # The patching_status object must be returned.
+        return patching_status
+
+    # onedal estimators with onedal models which can be modified must have
+    # the necessary attributes linked. This way the states of the two
+    # estimators do not diverge, as modifications could impact the inference
+    # results. This is not always necessary, as some estimators generate a
+    # model for predict during fit which cannot be modified. The easiest
+    # way to check for this is whether the oneDAL estimator contains a
+    # "model" method.
+
+    @property
+    def constant_(self):
+        return self._constant_
+
+    # The fitted attributes match the data framework and device of the
+    # inputs. Modifications of these output attributes to different
+    # frameworks or devices may not work and are not monitored by the
+    # oneDAL estimator.
+    # This matches the behavior which occurs in sklearn, and it is
+    # therefore up to the user to guarantee operation, especially in
+    # methods which depend on fitted estimators and attributes (like
+    # `transform` or `predict`).
+    @constant_.setter
+    def constant_(self, value):
+        self._constant_ = value
+        if hasattr(self, "_onedal_estimator"):
+            self._onedal_estimator._onedal_model = None
+            self._onedal_estimator.constant_ = value
+
+    @constant_.deleter
+    def constant_(self):
+        del self._constant_
+
+    # score is a tier 3 method in this case. Wrap it with
+    # ``support_input_format`` for array API support.
+    score = support_input_format(_sklearn_DummyRegressor.score)
+
+    # Docstrings should be inherited from the sklearn estimator if possible.
+    # In sklearnex-only estimators, they should be written from scratch
+    # using the numpydoc standard.
+    __doc__ = _sklearn_DummyRegressor.__doc__
+    fit.__doc__ = _sklearn_DummyRegressor.fit.__doc__
+    predict.__doc__ = _sklearn_DummyRegressor.predict.__doc__
+    score.__doc__ = _sklearn_DummyRegressor.score.__doc__
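+
+
+# Once patched (e.g. via ``from sklearnex import patch_sklearn;
+# patch_sklearn()``), ``from sklearn.dummy import DummyRegressor`` resolves
+# to this class through the mapping added in sklearnex/dispatcher.py. A
+# minimal usage sketch (illustrative):
+#
+#     >>> import numpy as np
+#     >>> from sklearnex.dummy import DummyRegressor
+#     >>> reg = DummyRegressor(strategy="constant", constant=1.5)
+#     >>> reg.fit(np.ones((4, 2)), np.zeros((4,))).predict(np.ones((3, 2)))
+#     array([1.5, 1.5, 1.5])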
diff --git a/sklearnex/dummy/tests/test_dummy.py b/sklearnex/dummy/tests/test_dummy.py
new file mode 100644
index 0000000000..7b88d4eaed
--- /dev/null
+++ b/sklearnex/dummy/tests/test_dummy.py
@@ -0,0 +1,62 @@
+# ===============================================================================
+# Copyright Contributors to the oneDAL Project
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ===============================================================================
+
+import numpy as np
+import pytest
+
+from daal4py.sklearn._utils import sklearn_check_version
+from onedal.tests.utils._dataframes_support import (
+    _as_numpy,
+    _convert_to_dataframe,
+    get_dataframes_and_queues,
+)
+from sklearnex import config_context
+from sklearnex.dummy import DummyRegressor
+
+
+@pytest.mark.parametrize("dataframe,queue", get_dataframes_and_queues())
+def test_sklearnex_import_DummyRegressor(dataframe, queue):
+    rng = np.random.default_rng(seed=42)
+
+    X = rng.random((10, 4))
+    y = rng.random((10,))
+    X = _convert_to_dataframe(X, sycl_queue=queue, target_df=dataframe)
+    y = _convert_to_dataframe(y, sycl_queue=queue, target_df=dataframe)
+    est = DummyRegressor(strategy="constant", constant=np.pi).fit(X, y)
+    assert "sklearnex" in est.__module__
+    pred = _as_numpy(est.predict([[0, 0, 0, 0]]))
+    np.testing.assert_array_equal(np.pi * np.ones(pred.shape), pred)
+
+
+@pytest.mark.skipif(
+    not sklearn_check_version("1.3"), reason="lacks sklearn array API support"
+)
+@pytest.mark.parametrize("dataframe,queue", get_dataframes_and_queues("dpctl,dpnp"))
+def test_fitted_attribute_conversion_DummyRegressor(dataframe, queue):
+    rng = np.random.default_rng(seed=42)
+
+    X = rng.random((10, 4))
+    y = rng.random((10,))
+    X = _convert_to_dataframe(X, sycl_queue=queue, target_df=dataframe)
+    y = _convert_to_dataframe(y, sycl_queue=queue, target_df=dataframe)
+    X_test = _convert_to_dataframe(
+        [[0, 0, 0, 0]], sycl_queue=queue, target_df=dataframe
+    )
+    with config_context(array_api_dispatch=True):
+        est = DummyRegressor(strategy="constant", constant=np.e).fit(X, y)
+        pred = _as_numpy(est.predict(X_test))
+
+    np.testing.assert_array_equal(np.full(pred.shape, np.e), pred)
+    est.constant_ = np.ones(est.constant_.shape)
+    np.testing.assert_array_equal(np.ones(pred.shape), est.predict([[0, 0, 0, 0]]))
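+
+
+def test_fallback_DummyRegressor(caplog):
+    # A hedged sketch of a fallback check (the exact log wording is an
+    # assumption): the support conditions only accept strategy="constant",
+    # so any other strategy should be dispatched to stock sklearn, which
+    # the "sklearnex" logger reports at DEBUG level.
+    import logging
+
+    rng = np.random.default_rng(seed=42)
+    X = rng.random((10, 4))
+    y = rng.random((10,))
+
+    with caplog.at_level(logging.DEBUG, logger="sklearnex"):
+        DummyRegressor(strategy="mean").fit(X, y)
+
+    assert any("sklearn" in message for message in caplog.messages)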
diff --git a/sklearnex/tests/test_common.py b/sklearnex/tests/test_common.py
index d8e3cb8188..d01597344d 100644
--- a/sklearnex/tests/test_common.py
+++ b/sklearnex/tests/test_common.py
@@ -103,6 +103,9 @@
     "LogisticRegression(solver='newton-cg')-predict-n_jobs_check": "uses daal4py for cpu in sklearnex",
     "LogisticRegression(solver='newton-cg')-predict_log_proba-n_jobs_check": "uses daal4py for cpu in sklearnex",
     "LogisticRegression(solver='newton-cg')-predict_proba-n_jobs_check": "uses daal4py for cpu in sklearnex",
+    "DummyRegressor-fit-n_jobs_check": "default parameters use sklearn",
+    "DummyRegressor-predict-n_jobs_check": "default parameters use sklearn",
+    "DummyRegressor-score-n_jobs_check": "default parameters use sklearn",
     # KNeighborsClassifier validate_data issues - will be fixed later
     "KNeighborsClassifier-fit-call_validate_data": "validate_data implementation needs fixing",
     "KNeighborsClassifier-predict_proba-call_validate_data": "validate_data implementation needs fixing",
diff --git a/sklearnex/tests/test_memory_usage.py b/sklearnex/tests/test_memory_usage.py
index f6f84b45ae..17d53e67c8 100644
--- a/sklearnex/tests/test_memory_usage.py
+++ b/sklearnex/tests/test_memory_usage.py
@@ -55,6 +55,7 @@
     "IncrementalPCA",  # TODO fix memory leak issue in private CI for data_shape = (1000, 100), data_transform_function = dataframe_f
     "IncrementalRidge",  # TODO fix memory leak issue in private CI for data_shape = (1000, 100), data_transform_function = dataframe_f
     "LogisticRegression(solver='newton-cg')",  # memory leak fortran (1000, 100)
+    "DummyRegressor",  # default parameters not supported
 )
 
 GPU_SKIP_LIST = (
@@ -73,6 +74,7 @@
     "NuSVC(probability=True)",  # does not support GPU offloading (fails silently)
     "IncrementalLinearRegression",  # issue with potrf with the specific dataset
     "LinearRegression",  # issue with potrf with the specific dataset
+    "DummyRegressor",  # default parameters not supported
 )
diff --git a/sklearnex/tests/test_n_jobs_support.py b/sklearnex/tests/test_n_jobs_support.py
index 45d63d8e90..d61546b103 100644
--- a/sklearnex/tests/test_n_jobs_support.py
+++ b/sklearnex/tests/test_n_jobs_support.py
@@ -74,6 +74,9 @@ def test_n_jobs_method_decoration(estimator):
 @pytest.mark.parametrize("n_jobs", [None, -1, 1, 2])
 def test_n_jobs_support(estimator, n_jobs, caplog):
+    if estimator == "DummyRegressor":
+        pytest.skip("default parameters fall back to sklearn")
+
     est = _get_estimator_instance(estimator)
     caplog.set_level(logging.DEBUG, logger="sklearnex")
diff --git a/sklearnex/tests/test_patching.py b/sklearnex/tests/test_patching.py
index c5b78bbe10..7c8db3232d 100755
--- a/sklearnex/tests/test_patching.py
+++ b/sklearnex/tests/test_patching.py
@@ -318,9 +318,17 @@ def test_patch_map_match():
     def list_all_attr(string):
         try:
-            modules = set(importlib.import_module(string).__all__)
+            mod = importlib.import_module(string)
         except ModuleNotFoundError:
-            modules = set([None])
+            return set([None])
+
+        # Some sklearn estimators exist in Python files
+        # rather than folders under sklearn
+        modules = set(
+            getattr(
+                mod, "__all__", [name for name in dir(mod) if not name.startswith("_")]
+            )
+        )
         return modules
 
     if _is_preview_enabled():
diff --git a/sklearnex/tests/test_run_to_run_stability.py b/sklearnex/tests/test_run_to_run_stability.py
index dc362c955b..1a39ad4f24 100755
--- a/sklearnex/tests/test_run_to_run_stability.py
+++ b/sklearnex/tests/test_run_to_run_stability.py
@@ -179,6 +179,8 @@ def test_standard_estimator_stability(estimator, method, dataframe, queue):
         pytest.skip(f"variation observed in {estimator}.score")
     if estimator in ["IncrementalEmpiricalCovariance"] and method == "mahalanobis":
         pytest.skip("allowed fallback to sklearn occurs")
+    if estimator == "DummyRegressor":
+        pytest.skip("default parameters fall back to sklearn")
 
     _skip_neighbors(estimator, method)
     if "NearestNeighbors" in estimator and "radius" in method:
diff --git a/sklearnex/tests/utils/base.py b/sklearnex/tests/utils/base.py
index e2494182ee..2966d5c0fd 100755
--- a/sklearnex/tests/utils/base.py
+++ b/sklearnex/tests/utils/base.py
@@ -39,6 +39,7 @@
 from onedal.utils._array_api import _get_sycl_namespace
 from sklearnex import get_patch_map, patch_sklearn, sklearn_is_patched, unpatch_sklearn
 from sklearnex.basic_statistics import BasicStatistics, IncrementalBasicStatistics
+from sklearnex.dummy import DummyRegressor
 from sklearnex.linear_model import LogisticRegression
 from sklearnex.neighbors import (
     KNeighborsClassifier,
@@ -137,6 +138,7 @@ def __getitem__(self, key):
             LogisticRegression(solver="newton-cg"),
             BasicStatistics(),
             IncrementalBasicStatistics(),
+            DummyRegressor(strategy="constant", constant=1.0),  # value set to 1.0 arbitrarily
         ]
     }
 )