28
28
# Module-level logger named after this module; used by the distance helpers below.
_LOG = logging .getLogger (__name__ )
29
29
30
30
31
- def calculate_dcrs (data : np .ndarray | None , query : np .ndarray | None ) -> np .ndarray | None :
31
+ def calculate_dcrs_nndrs (
32
+ data : np .ndarray | None , query : np .ndarray | None
33
+ ) -> tuple [np .ndarray | None , np .ndarray | None ]:
32
34
"""
33
- Calculate Distance to Closest Records (DCRs).
35
+ Calculate Distance to Closest Records (DCRs) and Nearest Neighbor Distance Ratios (NNDRs) .
34
36
35
37
Args:
36
38
data: Embeddings of the training data.
@@ -39,19 +41,21 @@ def calculate_dcrs(data: np.ndarray | None, query: np.ndarray | None) -> np.ndar
39
41
Returns:
40
42
"""
41
43
if data is None or query is None :
42
- return None
44
+ return None , None
43
45
# sort data by first dimension to enforce deterministic results
44
46
data = data [data [:, 0 ].argsort ()]
45
47
_LOG .info (f"calculate DCRs for { data .shape = } and { query .shape = } " )
46
- index = NearestNeighbors (n_neighbors = 1 , algorithm = "auto" , metric = "cosine" , n_jobs = min (cpu_count () - 1 , 16 ))
48
+ index = NearestNeighbors (n_neighbors = 2 , algorithm = "auto" , metric = "cosine" , n_jobs = min (cpu_count () - 1 , 16 ))
47
49
index .fit (data )
48
50
dcrs , _ = index .kneighbors (query )
49
- return dcrs [:, 0 ]
51
+ dcr = dcrs [:, 0 ]
52
+ nndr = (dcrs [:, 0 ] + 1e-8 ) / (dcrs [:, 1 ] + 1e-8 )
53
+ return dcr , nndr
50
54
51
55
52
56
def calculate_distances(
    *, syn_embeds: np.ndarray, trn_embeds: np.ndarray, hol_embeds: np.ndarray | None
) -> dict[str, np.ndarray | None]:
    """
    Calculates distances to the closest records (DCR) and nearest neighbor distance ratios (NNDR).

    Args:
        syn_embeds: Embeddings of synthetic data.
        trn_embeds: Embeddings of training data.
        hol_embeds: Embeddings of holdout data. If given, must have the same shape as `trn_embeds`.

    Returns:
        Dictionary containing:
            - dcr_syn_trn: DCR for synthetic to training.
            - dcr_syn_hol: DCR for synthetic to holdout.
            - dcr_trn_hol: DCR for training to holdout.
            - nndr_syn_trn: NNDR for synthetic to training.
            - nndr_syn_hol: NNDR for synthetic to holdout.
            - nndr_trn_hol: NNDR for training to holdout.
        The `*_syn_hol` and `*_trn_hol` entries are None if `hol_embeds` is None.
    """
    if hol_embeds is not None:
        assert trn_embeds.shape == hol_embeds.shape, "trn_embeds and hol_embeds must have the same shape"

    # calculate DCR / NNDR for synthetic to training
    dcr_syn_trn, nndr_syn_trn = calculate_dcrs_nndrs(data=trn_embeds, query=syn_embeds)
    # calculate DCR / NNDR for synthetic to holdout
    dcr_syn_hol, nndr_syn_hol = calculate_dcrs_nndrs(data=hol_embeds, query=syn_embeds)
    # calculate DCR / NNDR for holdout to training (holdout rows queried against training rows)
    dcr_trn_hol, nndr_trn_hol = calculate_dcrs_nndrs(data=trn_embeds, query=hol_embeds)

    # log statistics
    def deciles(x: np.ndarray) -> np.ndarray:
        # 0%, 10%, ..., 100% quantiles, rounded for compact log output
        return np.round(np.quantile(x, np.linspace(0, 1, 11)), 3)

    _LOG.info(f"DCR deciles for synthetic to training: {deciles(dcr_syn_trn)}")
    _LOG.info(f"NNDR deciles for synthetic to training: {deciles(nndr_syn_trn)}")
    if dcr_syn_hol is not None:
        _LOG.info(f"DCR deciles for synthetic to holdout: {deciles(dcr_syn_hol)}")
        _LOG.info(f"NNDR deciles for synthetic to holdout: {deciles(nndr_syn_hol)}")
        _LOG.info(f"share of dcr_syn_trn < dcr_syn_hol: {np.mean(dcr_syn_trn < dcr_syn_hol):.1%}")
        _LOG.info(f"share of nndr_syn_trn < nndr_syn_hol: {np.mean(nndr_syn_trn < nndr_syn_hol):.1%}")
        _LOG.info(f"share of dcr_syn_trn > dcr_syn_hol: {np.mean(dcr_syn_trn > dcr_syn_hol):.1%}")
        _LOG.info(f"share of nndr_syn_trn > nndr_syn_hol: {np.mean(nndr_syn_trn > nndr_syn_hol):.1%}")
    if dcr_trn_hol is not None:
        _LOG.info(f"DCR deciles for training to holdout: {deciles(dcr_trn_hol)}")
        _LOG.info(f"NNDR deciles for training to holdout: {deciles(nndr_trn_hol)}")

    return {
        "dcr_syn_trn": dcr_syn_trn,
        "nndr_syn_trn": nndr_syn_trn,
        "dcr_syn_hol": dcr_syn_hol,
        "nndr_syn_hol": nndr_syn_hol,
        "dcr_trn_hol": dcr_trn_hol,
        "nndr_trn_hol": nndr_trn_hol,
    }
93
111
112
+ def plot_distances (plot_title : str , distances : dict [str , np .ndarray ]) -> go .Figure :
113
+ dcr_syn_trn = distances ["dcr_syn_trn" ]
114
+ dcr_syn_hol = distances ["dcr_syn_hol" ]
115
+ dcr_trn_hol = distances ["dcr_trn_hol" ]
116
+ nndr_syn_trn = distances ["nndr_syn_trn" ]
117
+ nndr_syn_hol = distances ["nndr_syn_hol" ]
118
+ nndr_trn_hol = distances ["nndr_trn_hol" ]
94
119
95
- def plot_distances (
96
- plot_title : str , dcr_syn_trn : np .ndarray , dcr_syn_hol : np .ndarray | None , dcr_trn_hol : np .ndarray | None
97
- ) -> go .Figure :
98
- # calculate quantiles
120
+ # calculate quantiles for DCR
99
121
y = np .linspace (0 , 1 , 101 )
100
- x_syn_trn = np .quantile (dcr_syn_trn , y )
122
+
123
+ # Calculate max values to use later
124
+ max_dcr_syn_trn = np .max (dcr_syn_trn )
125
+ max_dcr_syn_hol = None if dcr_syn_hol is None else np .max (dcr_syn_hol )
126
+ max_dcr_trn_hol = None if dcr_trn_hol is None else np .max (dcr_trn_hol )
127
+ max_nndr_syn_trn = np .max (nndr_syn_trn )
128
+ max_nndr_syn_hol = None if nndr_syn_hol is None else np .max (nndr_syn_hol )
129
+ max_nndr_trn_hol = None if nndr_trn_hol is None else np .max (nndr_trn_hol )
130
+
131
+ # Ensure first point is always at x=0 for all lines
132
+ # and last point is at the maximum x value with y=1
133
+ x_dcr_syn_trn = np .concatenate ([[0 ], np .quantile (dcr_syn_trn , y [1 :- 1 ]), [max_dcr_syn_trn ]])
101
134
if dcr_syn_hol is not None :
102
- x_syn_hol = np .quantile (dcr_syn_hol , y )
135
+ x_dcr_syn_hol = np .concatenate ([[ 0 ], np . quantile (dcr_syn_hol , y [ 1 : - 1 ]), [ max_dcr_syn_hol ]] )
103
136
else :
104
- x_syn_hol = None
137
+ x_dcr_syn_hol = None
105
138
106
139
if dcr_trn_hol is not None :
107
- x_trn_hol = np .quantile (dcr_trn_hol , y )
140
+ x_dcr_trn_hol = np .concatenate ([[ 0 ], np . quantile (dcr_trn_hol , y [ 1 : - 1 ]), [ max_dcr_trn_hol ]] )
108
141
else :
109
- x_trn_hol = None
142
+ x_dcr_trn_hol = None
143
+
144
+ # calculate quantiles for NNDR
145
+ x_nndr_syn_trn = np .concatenate ([[0 ], np .quantile (nndr_syn_trn , y [1 :- 1 ]), [max_nndr_syn_trn ]])
146
+ if nndr_syn_hol is not None :
147
+ x_nndr_syn_hol = np .concatenate ([[0 ], np .quantile (nndr_syn_hol , y [1 :- 1 ]), [max_nndr_syn_hol ]])
148
+ else :
149
+ x_nndr_syn_hol = None
150
+
151
+ if nndr_trn_hol is not None :
152
+ x_nndr_trn_hol = np .concatenate ([[0 ], np .quantile (nndr_trn_hol , y [1 :- 1 ]), [max_nndr_trn_hol ]])
153
+ else :
154
+ x_nndr_trn_hol = None
155
+
156
+ # Adjust y to match the new x arrays with the added 0 and 1 points
157
+ y = np .concatenate ([[0 ], y [1 :- 1 ], [1 ]])
110
158
111
159
# prepare layout
112
160
layout = go .Layout (
@@ -120,80 +168,132 @@ def plot_distances(
120
168
plot_bgcolor = CHARTS_COLORS ["background" ],
121
169
autosize = True ,
122
170
height = 500 ,
123
- margin = dict (l = 20 , r = 20 , b = 20 , t = 40 , pad = 5 ),
171
+ margin = dict (l = 20 , r = 20 , b = 20 , t = 60 , pad = 5 ),
124
172
showlegend = True ,
125
- yaxis = dict (
126
- showticklabels = False ,
127
- zeroline = True ,
128
- zerolinewidth = 1 ,
129
- zerolinecolor = "#999999" ,
130
- rangemode = "tozero" ,
173
+ )
174
+
175
+ # Create a figure with two subplots side by side
176
+ fig = go .Figure (layout = layout ).set_subplots (
177
+ rows = 1 ,
178
+ cols = 2 ,
179
+ horizontal_spacing = 0.05 ,
180
+ subplot_titles = ("Distance to Closest Record (DCR)" , "Nearest Neighbor Distance Ratio (NNDR)" ),
181
+ )
182
+ fig .update_annotations (font_size = 12 )
183
+
184
+ # Configure axes for both subplots
185
+ for i in range (1 , 3 ):
186
+ fig .update_xaxes (
187
+ col = i ,
131
188
showline = True ,
132
189
linewidth = 1 ,
133
190
linecolor = "#999999" ,
134
- ),
135
- yaxis2 = dict (
136
- overlaying = "y" ,
137
- side = "right" ,
191
+ hoverformat = ".3f" ,
192
+ )
193
+
194
+ # Only show y-axis on the right side with percentage labels
195
+ fig .update_yaxes (
196
+ col = i ,
138
197
tickformat = ".0%" ,
139
198
showgrid = False ,
140
- range = [0 , 1 ],
199
+ range = [- 0.01 , 1.01 ],
141
200
showline = True ,
142
201
linewidth = 1 ,
143
202
linecolor = "#999999" ,
203
+ side = "right" ,
204
+ tickvals = [0 , 0.2 , 0.4 , 0.6 , 0.8 , 1.0 ],
205
+ )
206
+
207
+ # Add traces for DCR plot (left subplot)
208
+ # training vs holdout (light gray)
209
+ if x_dcr_trn_hol is not None :
210
+ fig .add_trace (
211
+ go .Scatter (
212
+ mode = "lines" ,
213
+ x = x_dcr_trn_hol ,
214
+ y = y ,
215
+ name = "Training vs. Holdout Data" ,
216
+ line = dict (color = "#999999" , width = 5 ),
217
+ showlegend = True ,
218
+ ),
219
+ row = 1 ,
220
+ col = 1 ,
221
+ )
222
+
223
+ # synthetic vs holdout (gray)
224
+ if x_dcr_syn_hol is not None :
225
+ fig .add_trace (
226
+ go .Scatter (
227
+ mode = "lines" ,
228
+ x = x_dcr_syn_hol ,
229
+ y = y ,
230
+ name = "Synthetic vs. Holdout Data" ,
231
+ line = dict (color = "#666666" , width = 5 ),
232
+ showlegend = True ,
233
+ ),
234
+ row = 1 ,
235
+ col = 1 ,
236
+ )
237
+
238
+ # synthetic vs training (green)
239
+ fig .add_trace (
240
+ go .Scatter (
241
+ mode = "lines" ,
242
+ x = x_dcr_syn_trn ,
243
+ y = y ,
244
+ name = "Synthetic vs. Training Data" ,
245
+ line = dict (color = "#24db96" , width = 5 ),
246
+ showlegend = True ,
144
247
),
145
- xaxis = dict (
146
- showline = True ,
147
- linewidth = 1 ,
148
- linecolor = "#999999" ,
149
- hoverformat = ".3f" ,
150
- ),
248
+ row = 1 ,
249
+ col = 1 ,
151
250
)
152
- fig = go .Figure (layout = layout )
153
-
154
- traces = []
155
251
252
+ # Add traces for NNDR plot (right subplot)
156
253
# training vs holdout (light gray)
157
- if x_trn_hol is not None :
158
- traces . append (
254
+ if x_nndr_trn_hol is not None :
255
+ fig . add_trace (
159
256
go .Scatter (
160
257
mode = "lines" ,
161
- x = x_trn_hol ,
258
+ x = x_nndr_trn_hol ,
162
259
y = y ,
163
260
name = "Training vs. Holdout Data" ,
164
261
line = dict (color = "#999999" , width = 5 ),
165
- yaxis = "y2" ,
166
- )
262
+ showlegend = False ,
263
+ ),
264
+ row = 1 ,
265
+ col = 2 ,
167
266
)
168
267
169
268
# synthetic vs holdout (gray)
170
- if x_syn_hol is not None :
171
- traces . append (
269
+ if x_nndr_syn_hol is not None :
270
+ fig . add_trace (
172
271
go .Scatter (
173
272
mode = "lines" ,
174
- x = x_syn_hol ,
273
+ x = x_nndr_syn_hol ,
175
274
y = y ,
176
275
name = "Synthetic vs. Holdout Data" ,
177
276
line = dict (color = "#666666" , width = 5 ),
178
- yaxis = "y2" ,
179
- )
277
+ showlegend = False ,
278
+ ),
279
+ row = 1 ,
280
+ col = 2 ,
180
281
)
181
282
182
283
# synthetic vs training (green)
183
- traces . append (
284
+ fig . add_trace (
184
285
go .Scatter (
185
286
mode = "lines" ,
186
- x = x_syn_trn ,
287
+ x = x_nndr_syn_trn ,
187
288
y = y ,
188
289
name = "Synthetic vs. Training Data" ,
189
290
line = dict (color = "#24db96" , width = 5 ),
190
- yaxis = "y2" ,
191
- )
291
+ showlegend = False ,
292
+ ),
293
+ row = 1 ,
294
+ col = 2 ,
192
295
)
193
296
194
- for trace in traces :
195
- fig .add_trace (trace )
196
-
197
297
fig .update_layout (
198
298
legend = dict (
199
299
orientation = "h" ,
@@ -210,12 +310,11 @@ def plot_distances(
210
310
211
311
212
312
def plot_store_distances(
    distances: dict[str, np.ndarray],
    workspace: TemporaryWorkspace,
) -> None:
    """
    Build the distance-metrics figure and store it as HTML in the workspace.

    Args:
        distances: Mapping of metric names to arrays, passed on to `plot_distances`.
        workspace: Workspace whose `store_figure_html` receives the figure.
    """
    plot_title = "Cumulative Distributions of Distance Metrics"
    figure = plot_distances(plot_title, distances)
    workspace.store_figure_html(figure, "distances_dcr")
0 commit comments