Skip to content

Commit 7b4f3e8

Browse files
authored
feat: warn if more than 50k embeddings are calculated (#202)
1 parent a13a117 commit 7b4f3e8

File tree

1 file changed

+14
-4
lines changed

1 file changed

+14
-4
lines changed

mostlyai/qa/_sampling.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
import logging
2929
import random
3030
from typing import Any
31+
import warnings
3132
from pandas.core.dtypes.common import is_numeric_dtype, is_datetime64_dtype
3233

3334
import numpy as np
@@ -316,15 +317,24 @@ def prepare_data_for_embeddings(
316317
hol_tgt_data = hol_tgt_data.drop(columns=[key]) if hol else None
317318

318319
# draw equally sized samples for fair 3-way comparison
319-
max_sample_size = min(
320+
max_sample_size_final = min(
320321
max_sample_size or float("inf"),
321322
len(syn_tgt_data),
322323
len(trn_tgt_data),
323324
len(hol_tgt_data) if hol_tgt_data is not None else float("inf"),
324325
)
325-
syn_tgt_data = syn_tgt_data.sample(n=max_sample_size)
326-
trn_tgt_data = trn_tgt_data.sample(n=max_sample_size)
327-
hol_tgt_data = hol_tgt_data.sample(n=max_sample_size) if hol else None
326+
syn_tgt_data = syn_tgt_data.sample(n=max_sample_size_final)
327+
trn_tgt_data = trn_tgt_data.sample(n=max_sample_size_final)
328+
hol_tgt_data = hol_tgt_data.sample(n=max_sample_size_final) if hol else None
329+
330+
if max_sample_size_final > 50_000 and max_sample_size is None:
331+
warnings.warn(
332+
UserWarning(
333+
"More than 50k embeddings will be calculated per dataset, which may take a long time. "
334+
"Consider setting a limit via `max_sample_size_embeddings` to speed up the process. "
335+
"Note however, that limiting the number of embeddings will affect the sensitivity of the distance metrics."
336+
)
337+
)
328338

329339
# limit to same columns
330340
trn_cols = list(trn_tgt_data.columns)[:EMBEDDINGS_MAX_COLUMNS]

0 commit comments

Comments (0)