|
63 | 63 | "cosine": 2,
|
64 | 64 | "hellinger": 1,
|
65 | 65 | "jaccard": 1,
|
| 66 | + "bit_jaccard": 1, |
66 | 67 | "dice": 1,
|
67 | 68 | }
|
68 | 69 |
|
@@ -2351,8 +2352,10 @@ def fit(self, X, y=None, force_all_finite=True):
|
2351 | 2352 | - 'allow-nan': accepts only np.nan and pd.NA values in array.
|
2352 | 2353 | Values cannot be infinite.
|
2353 | 2354 | """
|
2354 |
| - |
2355 |
| - X = check_array(X, dtype=np.float32, accept_sparse="csr", order="C", force_all_finite=force_all_finite) |
| 2355 | + if self.metric in ("bit_hamming", "bit_jaccard"): |
| 2356 | + X = check_array(X, dtype=np.uint8, order="C", force_all_finite=force_all_finite) |
| 2357 | + else: |
| 2358 | + X = check_array(X, dtype=np.float32, accept_sparse="csr", order="C", force_all_finite=force_all_finite) |
2356 | 2359 | self._raw_data = X
|
2357 | 2360 |
|
2358 | 2361 | # Handle all the optional arguments, setting default
|
@@ -2926,7 +2929,10 @@ def transform(self, X, force_all_finite=True):
|
2926 | 2929 | "Transform unavailable when model was fit with only a single data sample."
|
2927 | 2930 | )
|
2928 | 2931 | # If we just have the original input then short circuit things
|
2929 |
| - X = check_array(X, dtype=np.float32, accept_sparse="csr", order="C", force_all_finite=force_all_finite) |
| 2932 | + if self.metric in ("bit_hamming", "bit_jaccard"): |
| 2933 | + X = check_array(X, dtype=np.uint8, order="C", force_all_finite=force_all_finite) |
| 2934 | + else: |
| 2935 | + X = check_array(X, dtype=np.float32, accept_sparse="csr", order="C", force_all_finite=force_all_finite) |
2930 | 2936 | x_hash = joblib.hash(X)
|
2931 | 2937 | if x_hash == self._input_hash:
|
2932 | 2938 | if self.transform_mode == "embedding":
|
@@ -3297,7 +3303,10 @@ def _output_dist_only(x, y, *kwds):
|
3297 | 3303 | return inv_transformed_points
|
3298 | 3304 |
|
3299 | 3305 | def update(self, X, force_all_finite=True):
|
3300 |
| - X = check_array(X, dtype=np.float32, accept_sparse="csr", order="C", force_all_finite=force_all_finite) |
| 3306 | + if self.metric in ("bit_hamming", "bit_jaccard"): |
| 3307 | + X = check_array(X, dtype=np.uint8, order="C", force_all_finite=force_all_finite) |
| 3308 | + else: |
| 3309 | + X = check_array(X, dtype=np.float32, accept_sparse="csr", order="C", force_all_finite=force_all_finite) |
3301 | 3310 | random_state = check_random_state(self.transform_seed)
|
3302 | 3311 | rng_state = random_state.randint(INT32_MIN, INT32_MAX, 3).astype(np.int64)
|
3303 | 3312 |
|
|
0 commit comments