From db38d4c686899c2d5233a952e82340ddb1b9467c Mon Sep 17 00:00:00 2001 From: David Cortes Date: Mon, 24 Nov 2025 16:17:25 +0100 Subject: [PATCH 1/9] update tsne for sklearn1.8 --- daal4py/sklearn/manifold/_t_sne.py | 187 +++++++++++++++-------------- doc/sources/algorithms.rst | 6 +- doc/sources/guide/acceleration.rst | 6 +- 3 files changed, 100 insertions(+), 99 deletions(-) diff --git a/daal4py/sklearn/manifold/_t_sne.py b/daal4py/sklearn/manifold/_t_sne.py index 66c78f9c3c..4d6fb83d76 100755 --- a/daal4py/sklearn/manifold/_t_sne.py +++ b/daal4py/sklearn/manifold/_t_sne.py @@ -122,36 +122,80 @@ def _daal_tsne(self, P, n_samples, X_embedded): return X_embedded + # Comment 2025-11-24: This appears to be a copy-paste from an earlier version of the original + # scikit-learn with some modifications to make calls to oneDAL under a narrow subset of + # allowed input parameters, copy-pasting the rest of the sklearn code when oneDAL is not + # called. Note that the conditions checked here are out of synch with the latest sklearn by now. + # An early 'is supported' check that offloads to stock sklearn was added later on, which results + # in having a lot of dead code paths in this function that can be safely removed. + # Note: this method is called from inside 'fit' from the base class in stock scikit-learn. + # Hence, the offloading logic is different than in other classes, as falling back to 'fit' + # from the base class would lead to a circular loop. 
def _fit(self, X, skip_num_points=0): """Private function to fit the model using X as training data.""" - if isinstance(self.init, str) and self.init == "warn": - warnings.warn( - "The default initialization in TSNE will change " - "from 'random' to 'pca' in 1.2.", - FutureWarning, - ) - self._init = "random" + + _patching_status = PatchingConditionsChain("sklearn.manifold.TSNE._tsne") + _patching_status.and_conditions( + [ + ( + self.method == "barnes_hut", + 'Used t-SNE method is not "barnes_hut" which is the only supported.', + ), + (self.n_components == 2, "Number of components != 2."), + (self.verbose == 0, "Verbose mode is set."), + ( + daal_check_version((2021, "P", 600)), + "oneDAL version is lower than 2021.6.", + ), + ( + not ( + isinstance(self.init, str) and self.init == "pca" and issparse(X) + ), + "PCA initialization is not supported with sparse input matrices.", + ), + # Note: these conditions below should result in errors, but stock scikit-learn + # does not check for errors at this exact point. Hence, this offloads the erroring + # out to the base class, wherever in the process they might be encountered. 
+ ( + np.isscalar(self.angle) and self.angle > 0.0 and self.angle < 1.0, + "'angle' must be between 0.0 - 1.0", + ), + (self.early_exaggeration >= 1.0, "early_exaggeration must be at least 1"), + ( + ( + isinstance(self.init, str) + and self.init in ["random", "pca", "warn"] + ) + or isinstance(self.init, np.ndarray), + "'init' must be 'exact', 'pca', or a numpy array.", + ), + ] + ) + _dal_ready = _patching_status.get_status(logs=True) + if not _dal_ready: + return super()._fit(X, skip_num_points) + + if not sklearn_check_version("1.2"): + if isinstance(self.init, str) and self.init == "warn": + warnings.warn( + "The default initialization in TSNE will change " + "from 'random' to 'pca' in 1.2.", + FutureWarning, + ) + self._init = "random" + else: + self._init = self.init else: self._init = self.init - if isinstance(self._init, str) and self._init == "pca" and issparse(X): - raise TypeError( - "PCA initialization is currently not supported " - "with the sparse input matrix. Use " - 'init="random" instead.' - ) - - if self.method not in ["barnes_hut", "exact"]: - raise ValueError("'method' must be 'barnes_hut' or 'exact'") - if self.angle < 0.0 or self.angle > 1.0: - raise ValueError("'angle' must be between 0.0 - 1.0") - if self.learning_rate == "warn": - warnings.warn( - "The default learning rate in TSNE will change " - "from 200.0 to 'auto' in 1.2.", - FutureWarning, - ) - self._learning_rate = 200.0 + if not sklearn_check_version("1.2"): + if self.learning_rate == "warn": + warnings.warn( + "The default learning rate in TSNE will change " + "from 200.0 to 'auto' in 1.2.", + FutureWarning, + ) + self._learning_rate = 200.0 else: self._learning_rate = self.learning_rate if self._learning_rate == "auto": @@ -227,28 +271,15 @@ def _fit(self, X, skip_num_points=0): "or provide the dense distance matrix." 
) - if self.method == "barnes_hut" and self.n_components > 3: - raise ValueError( - "'n_components' should be inferior to 4 for the " - "barnes_hut algorithm as it relies on " - "quad-tree or oct-tree." - ) random_state = check_random_state(self.random_state) - if self.early_exaggeration < 1.0: - raise ValueError( - "early_exaggeration must be at least 1, but is {}".format( - self.early_exaggeration - ) - ) - if not sklearn_check_version("1.2"): if self.n_iter < 250: raise ValueError("n_iter should be at least 250") n_samples = X.shape[0] - neighbors_nn = None + # neighbors_nn = None # <- unused variable in stock sklearn, commented out due to coverity if self.method == "exact": # Retrieve the distance matrix, either using the precomputed one or # computing it. @@ -278,10 +309,7 @@ def _fit(self, X, skip_num_points=0): "All distances should be positive, the " "metric given is not correct" ) - if ( - self.metric != "euclidean" - and getattr(self, "square_distances", True) is True - ): + if self.metric != "euclidean": distances **= 2 # compute the joint probability distribution for the input space @@ -339,16 +367,12 @@ def _fit(self, X, skip_num_points=0): # Free the memory used by the ball_tree del knn - if ( - getattr(self, "square_distances", True) is True - or self.metric == "euclidean" - ): - # knn return the euclidean distance but we need it squared - # to be consistent with the 'exact' method. Note that the - # the method was derived using the euclidean method as in the - # input space. Not sure of the implication of using a different - # metric. - distances_nn.data **= 2 + # knn return the euclidean distance but we need it squared + # to be consistent with the 'exact' method. Note that the + # the method was derived using the euclidean method as in the + # input space. Not sure of the implication of using a different + # metric. 
+ distances_nn.data **= 2 # compute the joint probability distribution for the input space P = _joint_probabilities_nn(distances_nn, self.perplexity, self.verbose) @@ -358,16 +382,22 @@ def _fit(self, X, skip_num_points=0): elif self._init == "pca": pca = PCA( n_components=self.n_components, - svd_solver="randomized", random_state=random_state, ) + # Always output a numpy array, no matter what is configured globally + pca.set_output(transform="default") X_embedded = pca.fit_transform(X).astype(np.float32, copy=False) - warnings.warn( - "The PCA initialization in TSNE will change to " - "have the standard deviation of PC1 equal to 1e-4 " - "in 1.2. This will ensure better convergence.", - FutureWarning, - ) + if sklearn_check_version("1.1") and not sklearn_check_version("1.2"): + warnings.warn( + "The PCA initialization in TSNE will change to " + "have the standard deviation of PC1 equal to 1e-4 " + "in 1.2. This will ensure better convergence.", + FutureWarning, + ) + if sklearn_check_version("1.2"): + # PCA is rescaled so that PC1 has standard deviation 1e-4 which is + # the default value for random initialization. See issue #18018. + X_embedded = X_embedded / np.std(X_embedded[:, 0]) * 1e-4 elif self._init == "random": # The embedding is initialized with iid samples from Gaussians with # standard deviation 1e-4. @@ -377,40 +407,11 @@ def _fit(self, X, skip_num_points=0): else: raise ValueError("'init' must be 'pca', 'random', or " "a numpy array") - # Degrees of freedom of the Student's t-distribution. The suggestion - # degrees_of_freedom = n_components - 1 comes from - # "Learning a Parametric Embedding by Preserving Local Structure" - # Laurens van der Maaten, 2009. - degrees_of_freedom = max(self.n_components - 1, 1) + # Note: by this point, stock sklearn would calculate degrees of freedom, but oneDAL + # doesn't use them. 
- _patching_status = PatchingConditionsChain("sklearn.manifold.TSNE._tsne") - _patching_status.and_conditions( - [ - ( - self.method == "barnes_hut", - 'Used t-SNE method is not "barnes_hut" which is the only supported.', - ), - (self.n_components == 2, "Number of components != 2."), - (self.verbose == 0, "Verbose mode is set."), - ( - daal_check_version((2021, "P", 600)), - "oneDAL version is lower than 2021.6.", - ), - ] - ) - _dal_ready = _patching_status.get_status(logs=True) - - if _dal_ready: - X_embedded = check_array(X_embedded, dtype=[np.float32, np.float64]) - return self._daal_tsne(P, n_samples, X_embedded=X_embedded) - return self._tsne( - P, - degrees_of_freedom, - n_samples, - X_embedded=X_embedded, - neighbors=neighbors_nn, - skip_num_points=skip_num_points, - ) + X_embedded = check_array(X_embedded, dtype=[np.float32, np.float64]) + return self._daal_tsne(P, n_samples, X_embedded=X_embedded) fit.__doc__ = BaseTSNE.fit.__doc__ fit_transform.__doc__ = BaseTSNE.fit_transform.__doc__ diff --git a/doc/sources/algorithms.rst b/doc/sources/algorithms.rst index 192802daca..1f9eacfb3a 100755 --- a/doc/sources/algorithms.rst +++ b/doc/sources/algorithms.rst @@ -185,11 +185,11 @@ Dimensionality Reduction - All parameters are supported except: - ``metric`` != 'euclidean' or `'minkowski'` with ``p`` != `2` - - ``n_components`` can only be `2` - + - ``method`` != ``"barnes_hut"`` + Refer to :ref:`TSNE acceleration details ` to learn more. 
- - Sparse data is not supported + - Sparse data with ``init`` = ``"pca"`` is not supported Nearest Neighbors ***************** diff --git a/doc/sources/guide/acceleration.rst b/doc/sources/guide/acceleration.rst index ea368b4029..d93623856a 100644 --- a/doc/sources/guide/acceleration.rst +++ b/doc/sources/guide/acceleration.rst @@ -34,9 +34,9 @@ The overall acceleration of TSNE depends on the acceleration of each of these al - ``metric`` != `'euclidean'` or `'minkowski'` with ``p`` != `2` - The Gradient Descent part of the algorithm supports all parameters except: - - ``n_components`` = `3` - - ``method`` = `'exact'` - - ``verbose`` != `0` + - ``n_components`` > ``2`` + - ``method`` = ``'exact'`` + - ``verbose`` != ``0`` To get better performance, use parameters supported by both components. From 9c01a8cddf79c2e759f226312b336bf16389c19d Mon Sep 17 00:00:00 2001 From: David Cortes Date: Tue, 25 Nov 2025 09:52:44 +0100 Subject: [PATCH 2/9] fix test --- sklearnex/manifold/tests/test_tsne.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/sklearnex/manifold/tests/test_tsne.py b/sklearnex/manifold/tests/test_tsne.py index 6c7e93f26b..98b332c31f 100755 --- a/sklearnex/manifold/tests/test_tsne.py +++ b/sklearnex/manifold/tests/test_tsne.py @@ -19,6 +19,8 @@ from numpy.testing import assert_allclose from sklearn.metrics.pairwise import pairwise_distances +from daal4py.sklearn._utils import sklearn_check_version + # Note: n_components must be 2 for now from onedal.tests.utils._dataframes_support import ( _as_numpy, @@ -161,8 +163,12 @@ def test_tsne_functionality_and_edge_cases( assert np.any(embedding != 0) +# Note: since sklearn1.2, the PCA initialization divides by standard deviations of components. +# Since those will be zeros for constant data, it will end up producing NaNs, hence it's not tested. 
@pytest.mark.parametrize("dataframe,queue", get_dataframes_and_queues()) -@pytest.mark.parametrize("init", ["pca", "random"]) +@pytest.mark.parametrize( + "init", ["random"] + (["pca"] if not sklearn_check_version("1.2") else []) +) @pytest.mark.parametrize("dtype", [np.float32, np.float64]) def test_tsne_constant_data(init, dataframe, queue, dtype): from sklearnex.manifold import TSNE From b3f852fd88559e13f7beefee23a12df4dcc284d4 Mon Sep 17 00:00:00 2001 From: David Cortes Date: Tue, 25 Nov 2025 09:53:11 +0100 Subject: [PATCH 3/9] more corrections --- daal4py/sklearn/manifold/_t_sne.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/daal4py/sklearn/manifold/_t_sne.py b/daal4py/sklearn/manifold/_t_sne.py index 4d6fb83d76..3fc54e8cbc 100755 --- a/daal4py/sklearn/manifold/_t_sne.py +++ b/daal4py/sklearn/manifold/_t_sne.py @@ -387,7 +387,7 @@ def _fit(self, X, skip_num_points=0): # Always output a numpy array, no matter what is configured globally pca.set_output(transform="default") X_embedded = pca.fit_transform(X).astype(np.float32, copy=False) - if sklearn_check_version("1.1") and not sklearn_check_version("1.2"): + if sklearn_check_version("1.0") and not sklearn_check_version("1.2"): warnings.warn( "The PCA initialization in TSNE will change to " "have the standard deviation of PC1 equal to 1e-4 " From 55a852999c7c3bc75eb52248b5be3920778c1d34 Mon Sep 17 00:00:00 2001 From: David Cortes Date: Tue, 25 Nov 2025 10:19:50 +0100 Subject: [PATCH 4/9] more fixes for older sklearn --- daal4py/sklearn/manifold/_t_sne.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/daal4py/sklearn/manifold/_t_sne.py b/daal4py/sklearn/manifold/_t_sne.py index 3fc54e8cbc..9dda26870d 100755 --- a/daal4py/sklearn/manifold/_t_sne.py +++ b/daal4py/sklearn/manifold/_t_sne.py @@ -384,8 +384,9 @@ def _fit(self, X, skip_num_points=0): n_components=self.n_components, random_state=random_state, ) - # Always output a numpy array, no matter what is configured 
globally - pca.set_output(transform="default") + if sklearn_check_version("1.2"): + # Always output a numpy array, no matter what is configured globally + pca.set_output(transform="default") X_embedded = pca.fit_transform(X).astype(np.float32, copy=False) if sklearn_check_version("1.0") and not sklearn_check_version("1.2"): warnings.warn( From d960a14a8988b08480946d68232c1c78e301b18e Mon Sep 17 00:00:00 2001 From: David Cortes Date: Tue, 25 Nov 2025 16:31:35 +0100 Subject: [PATCH 5/9] missing else --- daal4py/sklearn/manifold/_t_sne.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/daal4py/sklearn/manifold/_t_sne.py b/daal4py/sklearn/manifold/_t_sne.py index 9dda26870d..bb6474d068 100755 --- a/daal4py/sklearn/manifold/_t_sne.py +++ b/daal4py/sklearn/manifold/_t_sne.py @@ -196,6 +196,8 @@ def _fit(self, X, skip_num_points=0): FutureWarning, ) self._learning_rate = 200.0 + else: + self._learning_rate = self.learning_rate else: self._learning_rate = self.learning_rate if self._learning_rate == "auto": From 8ebab02b68f7fc52bcae734e5566ca9965c73a61 Mon Sep 17 00:00:00 2001 From: David Cortes Date: Tue, 25 Nov 2025 16:50:44 +0100 Subject: [PATCH 6/9] more fixes for older sklearn --- daal4py/sklearn/manifold/_t_sne.py | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/daal4py/sklearn/manifold/_t_sne.py b/daal4py/sklearn/manifold/_t_sne.py index bb6474d068..293ff7d183 100755 --- a/daal4py/sklearn/manifold/_t_sne.py +++ b/daal4py/sklearn/manifold/_t_sne.py @@ -276,6 +276,12 @@ def _fit(self, X, skip_num_points=0): random_state = check_random_state(self.random_state) if not sklearn_check_version("1.2"): + if self.early_exaggeration < 1.0: + raise ValueError( + "early_exaggeration must be at least 1, but is {}".format( + self.early_exaggeration + ) + ) if self.n_iter < 250: raise ValueError("n_iter should be at least 250") @@ -311,7 +317,9 @@ def _fit(self, X, skip_num_points=0): "All distances should be positive, the " "metric given is 
not correct" ) - if self.metric != "euclidean": + if self.metric != "euclidean" and ( + sklearn_check_version("1.2") or self.square_distances is True + ): distances **= 2 # compute the joint probability distribution for the input space @@ -374,7 +382,10 @@ def _fit(self, X, skip_num_points=0): # the method was derived using the euclidean method as in the # input space. Not sure of the implication of using a different # metric. - distances_nn.data **= 2 + if sklearn_check_version("1.2") or ( + self.metric != "euclidean" and self.square_distances is True + ): + distances_nn.data **= 2 # compute the joint probability distribution for the input space P = _joint_probabilities_nn(distances_nn, self.perplexity, self.verbose) From 42d8d50fd398cd252d73591a409a6a7289c405ce Mon Sep 17 00:00:00 2001 From: David Cortes Date: Tue, 25 Nov 2025 16:55:08 +0100 Subject: [PATCH 7/9] correction --- daal4py/sklearn/manifold/_t_sne.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/daal4py/sklearn/manifold/_t_sne.py b/daal4py/sklearn/manifold/_t_sne.py index 293ff7d183..0ae543ffe8 100755 --- a/daal4py/sklearn/manifold/_t_sne.py +++ b/daal4py/sklearn/manifold/_t_sne.py @@ -383,7 +383,7 @@ def _fit(self, X, skip_num_points=0): # input space. Not sure of the implication of using a different # metric. 
if sklearn_check_version("1.2") or ( - self.metric != "euclidean" and self.square_distances is True + self.square_distances is True or self.metric == "euclidean" ): distances_nn.data **= 2 From 13ce36455640386a0f9082e3b4b91de701c9e1c7 Mon Sep 17 00:00:00 2001 From: David Cortes Date: Tue, 25 Nov 2025 17:27:15 +0100 Subject: [PATCH 8/9] remove redundant check --- daal4py/sklearn/manifold/_t_sne.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/daal4py/sklearn/manifold/_t_sne.py b/daal4py/sklearn/manifold/_t_sne.py index 0ae543ffe8..4fb5cb5fdf 100755 --- a/daal4py/sklearn/manifold/_t_sne.py +++ b/daal4py/sklearn/manifold/_t_sne.py @@ -276,12 +276,6 @@ def _fit(self, X, skip_num_points=0): random_state = check_random_state(self.random_state) if not sklearn_check_version("1.2"): - if self.early_exaggeration < 1.0: - raise ValueError( - "early_exaggeration must be at least 1, but is {}".format( - self.early_exaggeration - ) - ) if self.n_iter < 250: raise ValueError("n_iter should be at least 250") From 2cbeeb9d78fbf78f21de53a509d4dd7cd2ebb970 Mon Sep 17 00:00:00 2001 From: David Cortes Date: Wed, 26 Nov 2025 09:39:44 +0100 Subject: [PATCH 9/9] more clear conditions --- daal4py/sklearn/manifold/_t_sne.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/daal4py/sklearn/manifold/_t_sne.py b/daal4py/sklearn/manifold/_t_sne.py index 4fb5cb5fdf..88d964e1fa 100755 --- a/daal4py/sklearn/manifold/_t_sne.py +++ b/daal4py/sklearn/manifold/_t_sne.py @@ -164,7 +164,14 @@ def _fit(self, X, skip_num_points=0): ( ( isinstance(self.init, str) - and self.init in ["random", "pca", "warn"] + and self.init + in ["random", "pca"] + + ( + ["warn"] + if sklearn_check_version("1.0") + and not sklearn_check_version("1.2") + else [] + ) ) or isinstance(self.init, np.ndarray), "'init' must be 'exact', 'pca', or a numpy array.", @@ -175,7 +182,7 @@ def _fit(self, X, skip_num_points=0): if not _dal_ready: return super()._fit(X, skip_num_points) - if not 
sklearn_check_version("1.2"): + if sklearn_check_version("1.0") and not sklearn_check_version("1.2"): if isinstance(self.init, str) and self.init == "warn": warnings.warn( "The default initialization in TSNE will change " @@ -188,7 +195,7 @@ def _fit(self, X, skip_num_points=0): else: self._init = self.init - if not sklearn_check_version("1.2"): + if sklearn_check_version("1.0") and not sklearn_check_version("1.2"): if self.learning_rate == "warn": warnings.warn( "The default learning rate in TSNE will change "