Skip to content

Commit df6fec7

Browse files
authored
feat: pair-wise DCR share; use cosine; make deterministic (#167)
1 parent 5685a85 commit df6fec7

File tree

2 files changed

+29
-22
lines changed

2 files changed

+29
-22
lines changed

mostlyai/qa/_distances.py

Lines changed: 27 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,27 @@
2828
_LOG = logging.getLogger(__name__)
2929

3030

31+
def calculate_dcrs(data: np.ndarray | None, query: np.ndarray | None) -> np.ndarray | None:
32+
"""
33+
Calculate Distance to Closest Records (DCRs).
34+
35+
Args:
36+
data: Embeddings of the training data.
37+
query: Embeddings of the query set.
38+
39+
Returns:
40+
"""
41+
if data is None or query is None:
42+
return None
43+
# sort data by first dimension to enforce deterministic results
44+
data = data[data[:, 0].argsort()]
45+
_LOG.info(f"calculate DCRs for {data.shape=} and {query.shape=}")
46+
index = NearestNeighbors(n_neighbors=1, algorithm="auto", metric="cosine", n_jobs=min(cpu_count() - 1, 16))
47+
index.fit(data)
48+
dcrs, _ = index.kneighbors(query)
49+
return dcrs[:, 0]
50+
51+
3152
def calculate_distances(
3253
*, syn_embeds: np.ndarray, trn_embeds: np.ndarray, hol_embeds: np.ndarray | None
3354
) -> tuple[np.ndarray, np.ndarray | None, np.ndarray | None]:
@@ -47,28 +68,13 @@ def calculate_distances(
4768
"""
4869
if hol_embeds is not None:
4970
assert trn_embeds.shape == hol_embeds.shape
50-
# calculate DCR for synthetic to training
51-
index_syn = NearestNeighbors(n_neighbors=1, algorithm="brute", metric="l2", n_jobs=min(cpu_count() - 1, 16))
52-
index_syn.fit(syn_embeds)
53-
_LOG.info(f"calculate DCRs for {len(syn_embeds):,} synthetic to {len(trn_embeds):,} training")
54-
dcrs_syn_trn, _ = index_syn.kneighbors(trn_embeds)
55-
dcr_syn_trn = dcrs_syn_trn[:, 0]
5671

57-
dcr_syn_hol = None
58-
dcr_trn_hol = None
59-
60-
if hol_embeds is not None:
61-
# calculate DCR for synthetic to holdout
62-
_LOG.info(f"calculate DCRs for {len(syn_embeds):,} synthetic to {len(hol_embeds):,} holdout")
63-
dcrs_syn_hol, _ = index_syn.kneighbors(hol_embeds)
64-
dcr_syn_hol = dcrs_syn_hol[:, 0]
65-
66-
# calculate DCR for training to holdout
67-
_LOG.info(f"calculate DCRs for {len(trn_embeds):,} training to {len(hol_embeds):,} holdout")
68-
index_trn = NearestNeighbors(n_neighbors=1, algorithm="brute", metric="l2", n_jobs=min(cpu_count() - 1, 16))
69-
index_trn.fit(trn_embeds)
70-
dcrs_trn_hol, _ = index_trn.kneighbors(hol_embeds)
71-
dcr_trn_hol = dcrs_trn_hol[:, 0]
72+
# calculate DCR for synthetic to training
73+
dcr_syn_trn = calculate_dcrs(data=trn_embeds, query=syn_embeds)
74+
# calculate DCR for synthetic to holdout
75+
dcr_syn_hol = calculate_dcrs(data=hol_embeds, query=syn_embeds)
76+
# calculate DCR for holdout to training
77+
dcr_trn_hol = calculate_dcrs(data=trn_embeds, query=hol_embeds)
7278

7379
dcr_syn_trn_deciles = np.round(np.quantile(dcr_syn_trn, np.linspace(0, 1, 11)), 3)
7480
_LOG.info(f"DCR deciles for synthetic to training: {dcr_syn_trn_deciles}")

tests/unit/test_distances.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
import numpy as np
1516
import pandas as pd
1617
import pytest
1718

@@ -54,7 +55,7 @@ def test_calculate_distances():
5455
assert len(dcr_syn_hol) == n
5556
assert len(dcr_trn_hol) == n
5657
assert dcr_syn_trn.min() > 0
57-
assert dcr_syn_hol.max() == 0
58+
assert np.isclose(dcr_syn_hol.max(), 0, atol=1e-6)
5859

5960
# test specifically that near matches do not report a distance of 0 due to rounding
6061
syn_embeds = calculate_embeddings(["a 0.0002"] * n, embedder=embedder)

0 commit comments

Comments
 (0)