28
28
_LOG = logging .getLogger (__name__ )
29
29
30
30
31
def calculate_dcrs(data: np.ndarray | None, query: np.ndarray | None) -> np.ndarray | None:
    """
    Calculate Distance to Closest Records (DCRs).

    For every row of ``query`` the cosine distance to its single nearest
    neighbor among the rows of ``data`` is computed.

    Args:
        data: Embeddings of the training data.
        query: Embeddings of the query set.

    Returns:
        The first (and only) neighbor-distance column returned by the
        nearest-neighbor search, i.e. one distance per row of ``query``.
        ``None`` if either ``data`` or ``query`` is ``None``.
    """
    if data is None or query is None:
        return None
    # sort data by first dimension to enforce deterministic results
    data = data[data[:, 0].argsort()]
    _LOG.info(f"calculate DCRs for { data .shape = } and { query .shape = } ")
    # Leave one core free and cap the pool at 16 workers, but never go below 1:
    # on a single-core host `cpu_count() - 1` is 0, which joblib rejects as an
    # invalid n_jobs. `or 1` also covers cpu_count implementations that may
    # return None when the core count cannot be determined.
    n_jobs = max(1, min((cpu_count() or 1) - 1, 16))
    index = NearestNeighbors(n_neighbors=1, algorithm="auto", metric="cosine", n_jobs=n_jobs)
    index.fit(data)
    dcrs, _ = index.kneighbors(query)
    return dcrs[:, 0]
50
+
51
+
31
52
def calculate_distances (
32
53
* , syn_embeds : np .ndarray , trn_embeds : np .ndarray , hol_embeds : np .ndarray | None
33
54
) -> tuple [np .ndarray , np .ndarray | None , np .ndarray | None ]:
@@ -47,28 +68,13 @@ def calculate_distances(
47
68
"""
48
69
if hol_embeds is not None :
49
70
assert trn_embeds .shape == hol_embeds .shape
50
- # calculate DCR for synthetic to training
51
- index_syn = NearestNeighbors (n_neighbors = 1 , algorithm = "brute" , metric = "l2" , n_jobs = min (cpu_count () - 1 , 16 ))
52
- index_syn .fit (syn_embeds )
53
- _LOG .info (f"calculate DCRs for { len (syn_embeds ):,} synthetic to { len (trn_embeds ):,} training" )
54
- dcrs_syn_trn , _ = index_syn .kneighbors (trn_embeds )
55
- dcr_syn_trn = dcrs_syn_trn [:, 0 ]
56
71
57
- dcr_syn_hol = None
58
- dcr_trn_hol = None
59
-
60
- if hol_embeds is not None :
61
- # calculate DCR for synthetic to holdout
62
- _LOG .info (f"calculate DCRs for { len (syn_embeds ):,} synthetic to { len (hol_embeds ):,} holdout" )
63
- dcrs_syn_hol , _ = index_syn .kneighbors (hol_embeds )
64
- dcr_syn_hol = dcrs_syn_hol [:, 0 ]
65
-
66
- # calculate DCR for training to holdout
67
- _LOG .info (f"calculate DCRs for { len (trn_embeds ):,} training to { len (hol_embeds ):,} holdout" )
68
- index_trn = NearestNeighbors (n_neighbors = 1 , algorithm = "brute" , metric = "l2" , n_jobs = min (cpu_count () - 1 , 16 ))
69
- index_trn .fit (trn_embeds )
70
- dcrs_trn_hol , _ = index_trn .kneighbors (hol_embeds )
71
- dcr_trn_hol = dcrs_trn_hol [:, 0 ]
72
+ # calculate DCR for synthetic to training
73
+ dcr_syn_trn = calculate_dcrs (data = trn_embeds , query = syn_embeds )
74
+ # calculate DCR for synthetic to holdout
75
+ dcr_syn_hol = calculate_dcrs (data = hol_embeds , query = syn_embeds )
76
+ # calculate DCR for holdout to training
77
+ dcr_trn_hol = calculate_dcrs (data = trn_embeds , query = hol_embeds )
72
78
73
79
dcr_syn_trn_deciles = np .round (np .quantile (dcr_syn_trn , np .linspace (0 , 1 , 11 )), 3 )
74
80
_LOG .info (f"DCR deciles for synthetic to training: { dcr_syn_trn_deciles } " )
0 commit comments