File tree 1 file changed +14
-4
lines changed 1 file changed +14
-4
lines changed Original file line number Diff line number Diff line change 28
28
import logging
29
29
import random
30
30
from typing import Any
31
+ import warnings
31
32
from pandas .core .dtypes .common import is_numeric_dtype , is_datetime64_dtype
32
33
33
34
import numpy as np
@@ -316,15 +317,24 @@ def prepare_data_for_embeddings(
316
317
hol_tgt_data = hol_tgt_data .drop (columns = [key ]) if hol else None
317
318
318
319
# draw equally sized samples for fair 3-way comparison
319
- max_sample_size = min (
320
+ max_sample_size_final = min (
320
321
max_sample_size or float ("inf" ),
321
322
len (syn_tgt_data ),
322
323
len (trn_tgt_data ),
323
324
len (hol_tgt_data ) if hol_tgt_data is not None else float ("inf" ),
324
325
)
325
- syn_tgt_data = syn_tgt_data .sample (n = max_sample_size )
326
- trn_tgt_data = trn_tgt_data .sample (n = max_sample_size )
327
- hol_tgt_data = hol_tgt_data .sample (n = max_sample_size ) if hol else None
326
+ syn_tgt_data = syn_tgt_data .sample (n = max_sample_size_final )
327
+ trn_tgt_data = trn_tgt_data .sample (n = max_sample_size_final )
328
+ hol_tgt_data = hol_tgt_data .sample (n = max_sample_size_final ) if hol else None
329
+
330
+ if max_sample_size_final > 50_000 and max_sample_size is None :
331
+ warnings .warn (
332
+ UserWarning (
333
+ "More than 50k embeddings will be calculated per dataset, which may take a long time. "
334
+ "Consider setting a limit via `max_sample_size_embeddings` to speed up the process. "
335
+ "Note however, that limiting the number of embeddings will affect the sensitivity of the distance metrics."
336
+ )
337
+ )
328
338
329
339
# limit to same columns
330
340
trn_cols = list (trn_tgt_data .columns )[:EMBEDDINGS_MAX_COLUMNS ]
You can’t perform that action at this time.
0 commit comments