1 change: 1 addition & 0 deletions .ci/scripts/select_sklearn_tests.py
@@ -52,6 +52,7 @@ def parse_tests_tree(entry, prefix=""):
"model_selection/tests": ["test_split.py", "test_validation.py"],
"neighbors/tests": ["test_lof.py", "test_neighbors.py", "test_neighbors_pipeline.py"],
"svm/tests": ["test_sparse.py", "test_svm.py"],
"tests": "test_dummy.py",
}
if sklearn_check_version("1.2"):
tests_map["tests"] = ["test_public_functions.py"]
1 change: 1 addition & 0 deletions onedal/__init__.py
@@ -122,6 +122,7 @@ def __repr__(self) -> str:
"_spmd_backend",
"covariance",
"decomposition",
"dummy",
"ensemble",
"neighbors",
"primitives",
2 changes: 2 additions & 0 deletions onedal/dal.cpp
@@ -78,6 +78,7 @@ ONEDAL_PY_INIT_MODULE(logistic_regression);
#if defined(ONEDAL_VERSION) && ONEDAL_VERSION >= 20240700
ONEDAL_PY_INIT_MODULE(finiteness_checker);
#endif // defined(ONEDAL_VERSION) && ONEDAL_VERSION >= 20240700
ONEDAL_PY_INIT_MODULE(dummy);
#endif // ONEDAL_DATA_PARALLEL_SPMD

#ifdef ONEDAL_DATA_PARALLEL_SPMD
@@ -138,6 +139,7 @@ PYBIND11_MODULE(_onedal_py_host, m) {
#if defined(ONEDAL_VERSION) && ONEDAL_VERSION >= 20240700
init_finiteness_checker(m);
#endif // defined(ONEDAL_VERSION) && ONEDAL_VERSION >= 20240700
init_dummy(m);
}
#endif // ONEDAL_DATA_PARALLEL_SPMD

19 changes: 19 additions & 0 deletions onedal/dummy/__init__.py
@@ -0,0 +1,19 @@
# ==============================================================================
# Copyright Contributors to the oneDAL Project
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

from .dummy import DummyEstimator

__all__ = ["DummyEstimator"]
195 changes: 195 additions & 0 deletions onedal/dummy/dummy.cpp
@@ -0,0 +1,195 @@
/*******************************************************************************
* Copyright Contributors to the oneDAL Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

Contributor

Is the location of the files dummy.cpp, dummy_onedal.hpp, prototype.py correct?
Should those be located in onedal/dummy/ folder?

Contributor

Also, having "dummy.py" instead of "prototype.py" might be more aligned with the rest of the codebase.

Contributor Author

Done!

#include "onedal/common.hpp"
#include "onedal/version.hpp"
// A fake oneDAL algorithm is included via the `dummy_onedal.hpp` header. In
// normal circumstances a header for the oneDAL algorithm would be
// included here from the oneDAL `oneapi/dal/algo/` folder.
#include "onedal/dummy/dummy_onedal.hpp"
Contributor

A comment here specifying that in practice this would instead look like #include oneapi/dal/algo/... would be useful

Contributor Author

added!

#include "oneapi/dal/table/common.hpp"
#include "oneapi/dal/table/homogen.hpp"

namespace py = pybind11;

// oneDAL-python interaction code is located in this namespace
namespace oneapi::dal::python {

// pybind11 structures and functions of the 'dummy' algorithm
namespace dummy {

template <typename Task, typename Ops>
struct method2t {
method2t(const Task& task, const Ops& ops) : ops(ops) {}
// this functor converts the method param into a valid oneDAL task.
// Tasks are specific to each algorithm, therefore method2t is often
// defined for each algo.
template <typename Float>
auto operator()(const py::dict& params) {
using namespace dal::dummy;
const auto method = params["method"].cast<std::string>();

ONEDAL_PARAM_DISPATCH_VALUE(method, "dense", ops, Float, method::dense);
ONEDAL_PARAM_DISPATCH_VALUE(method, "by_default", ops, Float, method::by_default);
ONEDAL_PARAM_DISPATCH_THROW_INVALID_VALUE(method);
}

Ops ops;
};

struct params2desc {
// This functor converts the params dictionary into a oneDAL descriptor
template <typename Float, typename Method, typename Task>
auto operator()(const py::dict& params) {
auto desc = dal::dummy::descriptor<Float, Method, Task>();

// Conversion of the params dict to oneDAL params occurs here, except
// for the ``method`` and ``fptype`` parameters, which are consumed by
// the method2t and fptype2t functors. The remaining parameters are
// assigned to the descriptor individually before returning.
const auto constant = params["constant"].cast<double>();
desc.set_constant(constant);

return desc;
}
};

// the following functions define the python interface methods for the
// oneDAL algorithms. They are templated for the policy (which may be host,
// dpc, or spmd), and task, which is defined per algorithm. They are all
// defined using lambda functions (a common occurrence for pybind11), but
// that is not a requirement.
template <typename Policy, typename Task>
void init_train_ops(py::module_& m) {
m.def("train", [](const Policy& policy, const py::dict& params, const table& data) {
using namespace dal::dummy;
using input_t = train_input<Task>;
// while there is a train_ops defined for each oneDAL algorithm
// which supports ``train``, this is the train_ops defined in
// onedal/common/dispatch_utils.hpp
train_ops ops(policy, input_t{ data }, params2desc{});
// fptype2t is defined in common/dispatch_utils.hpp and operates
// in a similar manner to the method2t functor: it selects the
// floating point datatype for the calculation
return fptype2t{ method2t{ Task{}, ops } }(params);
});
}

template <typename Policy, typename Task>
void init_infer_ops(py::module_& m) {
m.def(
"infer",
[](const Policy& policy, const py::dict& params, const table& constant, const table& data) {
using namespace dal::dummy;
using input_t = infer_input<Task>;

infer_ops ops(policy, input_t{ data, constant }, params2desc{});
// With the use of functors the order of operations is as
// follows: Task is generated, the ops is already created above,
// method2t is constructed, and then fptype2t is constructed.
// The functors are then evaluated in the opposite order on the
// params dict: fptype2t first, then method2t, then ops.
return fptype2t{ method2t{ Task{}, ops } }(params);
});
}

// This defines the result C++ objects for use in python via pybind11.
// Result object attributes should be pybind11 native types (like int,
// float, etc.) or oneDAL tables.

template <typename Task>
void init_train_result(py::module_& m) {
using namespace dal::dummy;
using result_t = train_result<Task>;

py::class_<result_t>(m, "train_result").def(py::init()).DEF_ONEDAL_PY_PROPERTY(data, result_t);
}

template <typename Task>
void init_infer_result(py::module_& m) {
using namespace dal::dummy;
using result_t = infer_result<Task>;

py::class_<result_t>(m, "infer_result").def(py::init()).DEF_ONEDAL_PY_PROPERTY(data, result_t);
}

ONEDAL_PY_DECLARE_INSTANTIATOR(init_train_result);
ONEDAL_PY_DECLARE_INSTANTIATOR(init_infer_result);
ONEDAL_PY_DECLARE_INSTANTIATOR(init_train_ops);
ONEDAL_PY_DECLARE_INSTANTIATOR(init_infer_ops);

} // namespace dummy

ONEDAL_PY_INIT_MODULE(dummy) {
using namespace dummy;
using namespace dal::detail;
using namespace dal::dummy;

// The task_list allows templates to be instantiated for multiple task
// types (like regression and classification). The use of 'types'
// is not required, and has special implications for the
// 'bind_default_backend' function as it creates submodules in python
// based on the task name. See the covariance implementation,
// where no task_list is used and a submodule of the algorithm is not
// made.
using task_list = types<task::generate>;
auto sub = m.def_submodule("dummy");

// explicitly define the templates based off of the policy and task
// lists. These instantiations lead to a cascade of fully-resolved
// templates from oneDAL. It begins by fully resolving functors defined
// here and the oneDAL descriptor. It then fully specifies functors in
// common/dispatch_utils.hpp, which starts resolving oneDAL objects
// for the algorithm like the train_ops/infer_ops functors defined there.
// This leads to a fair amount of compile-time work with oneDAL headers.
// For example take init_train_ops in approximate reverse order
// (to show how it goes from here to oneDAL):
//
// 0. Creates pybind11 interface
// 1. Specifies lambda defined in init_train_ops
// 2. Specifies fptype2t
// 3. Specifies method2t
// 4. Specifies train_ops defined in common/dispatch_utils.hpp
// 5. Specifies train defined in oneapi/dal/train.hpp
// 6. Specifies train_dispatch in oneapi/dal/detail/train_ops.hpp
// 7. Specifies several functors in oneapi/dal/detail/ops_dispatcher.hpp
// 8. Specifies train_ops defined in algorithm's train_ops.hpp
// 9. Specifies oneDAL train_input, train_result and descriptor structs
/**** finally hits objects compiled in oneDAL for the computation ****/
// (train_ops_dispatcher for example)
//
// It's not clear how many layers of these indirections are compiled
// versus optimized away. The namings in dispatch_utils.hpp are also
// unfortunate and confusing.

// policy_list is defined elsewhere and depends on the backend being
// built. It is placed within a macro check in order to prevent
// use with an spmd policy.
#ifndef ONEDAL_DATA_PARALLEL_SPMD
Contributor

what about else (if spmd to be instantiated)?

Contributor Author

added a comment

ONEDAL_PY_INSTANTIATE(init_train_ops, sub, policy_list, task_list);
ONEDAL_PY_INSTANTIATE(init_infer_ops, sub, policy_list, task_list);
ONEDAL_PY_INSTANTIATE(init_train_result, sub, task_list);
ONEDAL_PY_INSTANTIATE(init_infer_result, sub, task_list);
#else
// This is where the pybind11 init functions would be instantiated with
// a policy_spmd object. For example, if an init_train_ops existed for
// the spmd backend it would be instantiated like:
// ONEDAL_PY_INSTANTIATE(init_train_ops, sub, policy_spmd, task_list);
#endif
}

ONEDAL_PY_TYPE2STR(dal::dummy::task::generate, "generate");

} // namespace oneapi::dal::python
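
The comments in `init_infer_ops` describe a chain of functors that is built from the inside out and evaluated from the outside in. The sketch below mirrors that ordering in plain Python so it can be run without compiling anything. It is only an analogy: the classes `Ops`, `Method2T` and `Fptype2T`, the `trace` list, and the use of strings for the float type and policy are all assumptions made for illustration, and none of them are part of the onedal API.

```python
# Hypothetical Python analogue of `fptype2t{ method2t{ Task{}, ops } }(params)`.
# Every name here is an illustrative stand-in, not part of the onedal API.

trace = []  # records the order in which each layer is built and called


class Ops:
    """Stand-in for train_ops: runs once float type and method are resolved."""

    def __init__(self, policy, data):
        trace.append("construct ops")
        self.policy = policy
        self.data = data

    def __call__(self, params, float_type, method):
        trace.append("call ops")
        return {"float": float_type, "method": method, "constant": params["constant"]}


class Method2T:
    """Stand-in for method2t: validates and forwards the ``method`` param."""

    def __init__(self, task, ops):
        trace.append("construct method2t")
        self.task = task
        self.ops = ops

    def __call__(self, params, float_type):
        trace.append("call method2t")
        method = params["method"]
        if method not in ("dense", "by_default"):
            raise ValueError(f"invalid value for parameter 'method': {method!r}")
        return self.ops(params, float_type, method)


class Fptype2T:
    """Stand-in for fptype2t: resolves the float type, then calls the inner functor."""

    def __init__(self, inner):
        trace.append("construct fptype2t")
        self.inner = inner

    def __call__(self, params):
        trace.append("call fptype2t")
        fptype = params["fptype"]
        if fptype not in ("float32", "float64"):
            raise ValueError(f"unsupported fptype: {fptype!r}")
        return self.inner(params, fptype)


params = {"fptype": "float32", "method": "dense", "constant": 5.0}
ops = Ops(policy="host", data=[[1.0, 2.0]])
result = Fptype2T(Method2T(task="generate", ops=ops))(params)

print(trace)
print(result)
```

The printed trace lists the three constructions first (ops, method2t, fptype2t) and then the three calls in the reverse order. In the C++ code each layer fixes a template argument rather than passing a runtime value, which this sketch does not attempt to model.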
137 changes: 137 additions & 0 deletions onedal/dummy/dummy.py
@@ -0,0 +1,137 @@
# ==============================================================================
# Copyright Contributors to the oneDAL Project
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""This file describes necessary characteristics and design patterns of onedal estimators.

This can be used as a foundation for developing other estimators. Most
comments guiding code development should be removed unless pertinent to the
implementation."""

from .._device_offload import supports_queue
from ..common._backend import bind_default_backend
from ..datatypes import from_table, to_table


class DummyEstimator:
    # This class creates a constant 2d array of specific size as an example

    def __init__(self, constant=False):
        # The __init__ method should only assign class attributes matching
        # the input parameters (similar to sklearn). It should not assign
        # any attributes which aren't related to the operation of oneDAL.
        # This means that it should not conform to sklearn, only to
        # oneDAL. Don't add unnecessary attributes which only match sklearn;
        # these should be translated by the sklearnex estimator. In this case
        # the only parameter for the dummy algorithm is the `constant` param.
        self.constant = constant
        self._onedal_model = None

    # See the documentation on bind_default_backend. There exist three
    # possible oneDAL pybind11 interfaces: 'host', 'dpc' and 'spmd'. These
    # are for cpu-only, cpu and gpu, and multi-device computation
    # respectively. Logic in the onedal module will determine which can be
    # used at import time. It will attempt to use the 'dpc' interface if
    # possible (which enables gpu computation), but this requires a SYCL
    # runtime. If not possible it will silently fall back to the 'host'
    # pybind11 interface. The backend binding logic will seamlessly handle
    # this for the estimator. The 'spmd' backend is specific to onedal
    # estimators defined in the 'spmd' folder. The binding makes the
    # pybind11 function a method of this class with the same name (in this
    # case DummyEstimator.train should call the pybind11 function
    # onedal.backend.dummy.generate.train) where backend can be one of
    # 'host', 'dpc' or 'spmd'.
    @bind_default_backend("dummy.generate")
    def train(self, params, data_table): ...

    @bind_default_backend("dummy.generate")
    def infer(self, params, model, data_table): ...

    @supports_queue
    def fit(self, X, y, queue=None):
        # convert the data to oneDAL tables in preparation for use by the
        # oneDAL pybind11 interfaces/objects.
        X_t, y_t = to_table(X, y)

        # Generating the params dict can be centralized into a class method,
        # but it must be named ``_get_onedal_params``. Parameter 'fptype' is
        # specific to the pybind11 interface, and cannot be found in oneDAL
        # documentation. This tells oneDAL what float type to use for the
        # computation. The safest and best way to assign this value is after
        # the input data has been converted to a oneDAL table, as the dtype
        # is standardized (taken care of by ``to_table``). This dtype is a
        # ``numpy`` dtype due to its ubiquity and native support in pybind11.
        params = {
            "fptype": y_t.dtype,  # normally X_t.dtype is used
            "method": "dense",
            "constant": self.constant,
        }

        # This is the call to the oneDAL pybind11 backend, which was
        # previously bound using ``bind_default_backend``. It returns a
        # pybind11 Python interface to the oneDAL C++ result object.
        result = self.train(params, y_t)
        # In general ``fit`` corresponds to oneDAL's ``train`` and
        # ``predict`` to oneDAL's ``infer``. Please refer to the oneDAL
        # design documentation (headers under oneDAL/tree/main/cpp/oneapi/dal
        # in the oneDAL repository) to determine the best translation,
        # including for other counterparts like ``compute`` and
        # ``partial_train``. Generally the sklearn naming scheme for class
        # methods should be used here, but calls to the pybind11 interfaces
        # should follow oneDAL naming.

        # Oftentimes oneDAL table objects are attributes of the oneDAL C++
        # object. These can be converted into various common data frameworks
        # like ``numpy`` or ``dpctl.tensor`` using ``from_table``. In this
        # case the output is a basic python type (bool) which can be handled
        # easily just with pybind11 without any special code. Attributes of
        # the result object are copied to attributes of the onedal estimator
        # object.

        self.constant_, self.fit_X_, self.fit_y_ = from_table(
            result.data, X_t, y_t, like=X
        )
        # The fit_X_ and fit_y_ attributes are not required and are generally
        # discouraged. They are set only as an example, to show the process
        # of setting and returning array values. When setting return
        # attributes, any post-processing of the values needed for sklearn
        # beyond conversion must occur in the sklearnex estimator.

    def _create_model(self):
        # While doing something rather trivial, this is closer to what may
        # occur in other estimators which can generate models just in time.
        # Necessary attributes are collected, converted to oneDAL tables
        # and set to the oneDAL object. In general there should be a oneDAL
        # model class defined with serialization and deserialization with a
        # pybind11 interface.

        # When the model is a oneDAL object (see svm), it must maintain a
        # pybind11 interface to the `serialization` and `deserialization`
        # oneDAL routines for proper pickling of the oneDAL object. oneDAL
        # tables must be converted to the array type of the fitted estimator
        # for pickling/unpickling (see any incremental estimator pybind11
        # implementation).

        # This example just treats a oneDAL table as the model.
        return to_table(self.constant_)

    @supports_queue
    def predict(self, X, queue=None):
        X_t = to_table(X)
        if self._onedal_model is None:
            self._onedal_model = self._create_model()

        params = {"fptype": X_t.dtype, "method": "dense", "constant": self.constant}
        result = self.infer(params, self._onedal_model, X_t)
        return from_table(result.data, like=X)
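
The comment above the `train` and `infer` stubs describes how `bind_default_backend` swaps a stub method for the pybind11 function of the same name, preferring the 'dpc' backend and silently falling back to 'host'. The following is a minimal sketch of that idea only, not the actual decorator. The names `bind_backend_sketch`, `SketchEstimator`, `host_backend`, `dpc_backend` and `_host_train` are hypothetical, the backends are plain `SimpleNamespace` objects rather than compiled modules, and passing the backend name as the policy is an assumption made to match the C++ lambdas, which take a policy as their first argument.

```python
import functools
from types import SimpleNamespace

# Hypothetical stand-ins for the compiled pybind11 modules. In this sketch the
# 'dpc' backend is treated as unavailable (no SYCL runtime), so 'host' is used.


def _host_train(policy, params, data):
    return {"policy": policy, "constant": params["constant"], "rows": len(data)}


host_backend = SimpleNamespace(
    name="host",
    dummy=SimpleNamespace(generate=SimpleNamespace(train=_host_train)),
)
dpc_backend = None


def bind_backend_sketch(module_path):
    """Replace a stub method with the backend function of the same name."""

    def decorator(stub):
        # Prefer 'dpc' when present, otherwise fall back to 'host'.
        backend = dpc_backend if dpc_backend is not None else host_backend
        module = backend
        for part in module_path.split("."):
            module = getattr(module, part)
        backend_func = getattr(module, stub.__name__)

        @functools.wraps(stub)
        def method(self, *args, **kwargs):
            # The C++ lambdas take a policy first; a plain string stands in here.
            return backend_func(backend.name, *args, **kwargs)

        return method

    return decorator


class SketchEstimator:
    @bind_backend_sketch("dummy.generate")
    def train(self, params, data_table): ...


out = SketchEstimator().train({"constant": 3.0}, [[0.0], [1.0]])
print(out)
```

The stub body is never executed: only its name is used to look up the backend function, which is why the estimator methods in `dummy.py` can be written as `...`.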