[FEA] Support for Dice coefficient as a metric in UMAP #5129

@beckernick

I'd like to be able to use the Dice coefficient as a UMAP metric in cuML, as I can with umap-learn on the CPU. Based on this GitHub search, a small number of users choose the Dice metric.

That said, since RAFT supports the Dice coefficient as a metric and cuML's DistanceType enum already includes DiceExpanded, this may be as simple as adding it as a supported metric in the UMAP metric mapping dictionary.
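
To illustrate the kind of change I mean, here is a minimal sketch of a name-to-enum lookup. Everything in it is a hypothetical stand-in: the `DistanceType` class below is a plain Python enum with placeholder values, and `METRIC_MAPPING`, `PATCHED_MAPPING`, and `resolve_metric` are illustrative names, not cuML's actual code.

```python
from enum import Enum


class DistanceType(Enum):
    """Hypothetical stand-in for cuML's DistanceType enum (placeholder values)."""
    L2SqrtUnexpanded = 0
    CosineExpanded = 1
    DiceExpanded = 2


# Hypothetical name-to-enum table; "dice" is absent, as in the failing case.
METRIC_MAPPING = {
    "euclidean": DistanceType.L2SqrtUnexpanded,
    "cosine": DistanceType.CosineExpanded,
}


def resolve_metric(name, mapping=METRIC_MAPPING):
    """Look up a metric name and reject anything missing from the table."""
    key = name.lower()
    if key not in mapping:
        raise ValueError(f"Invalid value for metric: {name}")
    return mapping[key]


# The proposed change would amount to one extra entry in the table.
PATCHED_MAPPING = {**METRIC_MAPPING, "dice": DistanceType.DiceExpanded}
```

With the unpatched table, `resolve_metric("dice")` raises a `ValueError`; with `PATCHED_MAPPING` it returns the `DiceExpanded` member. Whether the real fix is this small depends on the GPU kernels accepting that distance type for UMAP's nearest-neighbor step, which I have not checked.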

import cuml
import umap

X, _ = cuml.datasets.make_blobs()

reducer = umap.UMAP(metric="dice")
print(reducer.fit_transform(X.get())[:5])
/home/nicholasb/miniconda3/envs/rapids-23.02/lib/python3.8/site-packages/umap/umap_.py:1802: UserWarning: gradient function is not yet implemented for dice distance metric; inverse_transform will be unavailable
  warn(
[[-1.687048   15.136897  ]
 [-1.1637341  14.760551  ]
 [ 1.4106333  14.12161   ]
 [-0.07474275 13.303681  ]
 [ 1.5589024  16.217863  ]]

reducer = cuml.manifold.umap.UMAP(metric="dice")
print(reducer.fit_transform(X)[:5])

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Input In [14], in <cell line: 10>()
      7 print(reducer.fit_transform(X.get())[:5])
      9 reducer = cuml.manifold.umap.UMAP(metric="dice")
---> 10 print(reducer.fit_transform(X)[:5])

File ~/miniconda3/envs/rapids-23.02/lib/python3.8/site-packages/cuml/internals/api_decorators.py:548, in BaseReturnArrayDecorator.__call__.<locals>.inner_set_get(*args, **kwargs)
    545         self.do_getters_with_self_no_input(self_val=self_val)
    547     # Call the function
--> 548     ret_val = func(*args, **kwargs)
    550 return cm.process_return(ret_val)

File ~/miniconda3/envs/rapids-23.02/lib/python3.8/site-packages/cuml/internals/api_decorators.py:817, in enable_device_interop.<locals>.dispatch(self, *args, **kwargs)
    815 if hasattr(self, 'dispatch_func'):
    816     func_name = gpu_func.__name__
--> 817     return self.dispatch_func(func_name, gpu_func, *args, **kwargs)
    818 else:
    819     return gpu_func(self, *args, **kwargs)

File ~/miniconda3/envs/rapids-23.02/lib/python3.8/site-packages/cuml/internals/api_decorators.py:359, in ReturnAnyDecorator.__call__.<locals>.inner(*args, **kwargs)
    356 @wraps(func)
    357 def inner(*args, **kwargs):
    358     with self._recreate_cm(func, args):
--> 359         return func(*args, **kwargs)

File base.pyx:656, in cuml.internals.base.UniversalBase.dispatch_func()

File umap.pyx:659, in cuml.manifold.umap.UMAP.fit_transform()

File ~/miniconda3/envs/rapids-23.02/lib/python3.8/site-packages/cuml/internals/api_decorators.py:408, in BaseReturnAnyDecorator.__call__.<locals>.inner_with_setters(*args, **kwargs)
    401 self_val, input_val, target_val = \
    402     self.get_arg_values(*args, **kwargs)
    404 self.do_setters(self_val=self_val,
    405                 input_val=input_val,
    406                 target_val=target_val)
--> 408 return func(*args, **kwargs)

File ~/miniconda3/envs/rapids-23.02/lib/python3.8/site-packages/cuml/internals/api_decorators.py:817, in enable_device_interop.<locals>.dispatch(self, *args, **kwargs)
    815 if hasattr(self, 'dispatch_func'):
    816     func_name = gpu_func.__name__
--> 817     return self.dispatch_func(func_name, gpu_func, *args, **kwargs)
    818 else:
    819     return gpu_func(self, *args, **kwargs)

File ~/miniconda3/envs/rapids-23.02/lib/python3.8/site-packages/cuml/internals/api_decorators.py:359, in ReturnAnyDecorator.__call__.<locals>.inner(*args, **kwargs)
    356 @wraps(func)
    357 def inner(*args, **kwargs):
    358     with self._recreate_cm(func, args):
--> 359         return func(*args, **kwargs)

File base.pyx:656, in cuml.internals.base.UniversalBase.dispatch_func()

File umap.pyx:569, in cuml.manifold.umap.UMAP.fit()

File umap.pyx:465, in cuml.manifold.umap.UMAP._build_umap_params()

ValueError: Invalid value for metric: dice

(Using the 23.02.00a230112 cuda11_py38_g69db20fd6_77 nightly conda package).
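
For reference, the Dice dissimilarity between two binary vectors is usually defined as the number of positions where they differ divided by twice the number of positions set in both plus the number where they differ. Below is a minimal pure-Python sketch of that definition. It assumes nonzero entries count as True and that two vectors with identical supports have distance 0. The function name is illustrative, and this is not RAFT's DiceExpanded implementation.

```python
def dice_dissimilarity(u, v):
    """Dice dissimilarity between two equal-length vectors (nonzero means True)."""
    if len(u) != len(v):
        raise ValueError("vectors must have the same length")
    # Positions that are set in both vectors.
    both = sum(1 for a, b in zip(u, v) if a != 0 and b != 0)
    # Positions that are set in exactly one of the two vectors.
    differ = sum(1 for a, b in zip(u, v) if (a != 0) != (b != 0))
    if differ == 0:
        return 0.0  # identical supports, including the all-zero case
    return differ / (2.0 * both + differ)


a = [1, 1, 0, 0]
b = [1, 0, 1, 0]
print(dice_dissimilarity(a, b))  # 2 / (2*1 + 2) = 0.5
```

A function like this could serve as a small CPU-side check when comparing GPU results for a handful of point pairs.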
