Skip to content

Commit 15f0da9

Browse files
authored
feat: plot NNDRs, plus report 10th smallest values as metrics (#169)
1 parent 25558f5 commit 15f0da9

File tree

6 files changed

+242
-111
lines changed

6 files changed

+242
-111
lines changed

mostlyai/qa/_distances.py

Lines changed: 172 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -28,9 +28,11 @@
2828
_LOG = logging.getLogger(__name__)
2929

3030

31-
def calculate_dcrs(data: np.ndarray | None, query: np.ndarray | None) -> np.ndarray | None:
31+
def calculate_dcrs_nndrs(
32+
data: np.ndarray | None, query: np.ndarray | None
33+
) -> tuple[np.ndarray | None, np.ndarray | None]:
3234
"""
33-
Calculate Distance to Closest Records (DCRs).
35+
Calculate Distance to Closest Records (DCRs) and Nearest Neighbor Distance Ratios (NNDRs).
3436
3537
Args:
3638
data: Embeddings of the training data.
@@ -39,19 +41,21 @@ def calculate_dcrs(data: np.ndarray | None, query: np.ndarray | None) -> np.ndar
3941
Returns:
4042
"""
4143
if data is None or query is None:
42-
return None
44+
return None, None
4345
# sort data by first dimension to enforce deterministic results
4446
data = data[data[:, 0].argsort()]
4547
_LOG.info(f"calculate DCRs for {data.shape=} and {query.shape=}")
46-
index = NearestNeighbors(n_neighbors=1, algorithm="auto", metric="cosine", n_jobs=min(cpu_count() - 1, 16))
48+
index = NearestNeighbors(n_neighbors=2, algorithm="auto", metric="cosine", n_jobs=min(cpu_count() - 1, 16))
4749
index.fit(data)
4850
dcrs, _ = index.kneighbors(query)
49-
return dcrs[:, 0]
51+
dcr = dcrs[:, 0]
52+
nndr = (dcrs[:, 0] + 1e-8) / (dcrs[:, 1] + 1e-8)
53+
return dcr, nndr
5054

5155

5256
def calculate_distances(
5357
*, syn_embeds: np.ndarray, trn_embeds: np.ndarray, hol_embeds: np.ndarray | None
54-
) -> tuple[np.ndarray, np.ndarray | None, np.ndarray | None]:
58+
) -> dict[str, np.ndarray]:
5559
"""
5660
Calculates distances to the closest records (DCR) and nearest neighbor distance ratios (NNDR).
5761
@@ -61,52 +65,96 @@ def calculate_distances(
6165
hol_embeds: Embeddings of holdout data.
6266
6367
Returns:
64-
Tuple containing:
68+
Dictionary containing:
6569
- dcr_syn_trn: DCR for synthetic to training.
6670
- dcr_syn_hol: DCR for synthetic to holdout.
6771
- dcr_trn_hol: DCR for training to holdout.
72+
- nndr_syn_trn: NNDR for synthetic to training.
73+
- nndr_syn_hol: NNDR for synthetic to holdout.
74+
- nndr_trn_hol: NNDR for training to holdout.
6875
"""
6976
if hol_embeds is not None:
7077
assert trn_embeds.shape == hol_embeds.shape
7178

72-
# calculate DCR for synthetic to training
73-
dcr_syn_trn = calculate_dcrs(data=trn_embeds, query=syn_embeds)
74-
# calculate DCR for synthetic to holdout
75-
dcr_syn_hol = calculate_dcrs(data=hol_embeds, query=syn_embeds)
76-
# calculate DCR for holdout to training
77-
dcr_trn_hol = calculate_dcrs(data=trn_embeds, query=hol_embeds)
79+
# calculate DCR / NNDR for synthetic to training
80+
dcr_syn_trn, nndr_syn_trn = calculate_dcrs_nndrs(data=trn_embeds, query=syn_embeds)
81+
# calculate DCR / NNDR for synthetic to holdout
82+
dcr_syn_hol, nndr_syn_hol = calculate_dcrs_nndrs(data=hol_embeds, query=syn_embeds)
83+
# calculate DCR / NNDR for holdout to training
84+
dcr_trn_hol, nndr_trn_hol = calculate_dcrs_nndrs(data=trn_embeds, query=hol_embeds)
7885

79-
dcr_syn_trn_deciles = np.round(np.quantile(dcr_syn_trn, np.linspace(0, 1, 11)), 3)
80-
_LOG.info(f"DCR deciles for synthetic to training: {dcr_syn_trn_deciles}")
86+
# log statistics
87+
def deciles(x):
88+
return np.round(np.quantile(x, np.linspace(0, 1, 11)), 3)
89+
90+
_LOG.info(f"DCR deciles for synthetic to training: {deciles(dcr_syn_trn)}")
91+
_LOG.info(f"NNDR deciles for synthetic to training: {deciles(nndr_syn_trn)}")
8192
if dcr_syn_hol is not None:
82-
dcr_syn_hol_deciles = np.round(np.quantile(dcr_syn_hol, np.linspace(0, 1, 11)), 3)
83-
_LOG.info(f"DCR deciles for synthetic to holdout: {dcr_syn_hol_deciles}")
84-
# calculate share of dcr_syn_trn != dcr_syn_hol
93+
_LOG.info(f"DCR deciles for synthetic to holdout: {deciles(dcr_syn_hol)}")
94+
_LOG.info(f"NNDR deciles for synthetic to holdout: {deciles(nndr_syn_hol)}")
8595
_LOG.info(f"share of dcr_syn_trn < dcr_syn_hol: {np.mean(dcr_syn_trn < dcr_syn_hol):.1%}")
96+
_LOG.info(f"share of nndr_syn_trn < nndr_syn_hol: {np.mean(nndr_syn_trn < nndr_syn_hol):.1%}")
8697
_LOG.info(f"share of dcr_syn_trn > dcr_syn_hol: {np.mean(dcr_syn_trn > dcr_syn_hol):.1%}")
87-
98+
_LOG.info(f"share of nndr_syn_trn > nndr_syn_hol: {np.mean(nndr_syn_trn > nndr_syn_hol):.1%}")
8899
if dcr_trn_hol is not None:
89-
dcr_trn_hol_deciles = np.round(np.quantile(dcr_trn_hol, np.linspace(0, 1, 11)), 3)
90-
_LOG.info(f"DCR deciles for training to holdout: {dcr_trn_hol_deciles}")
100+
_LOG.info(f"DCR deciles for training to holdout: {deciles(dcr_trn_hol)}")
101+
_LOG.info(f"NNDR deciles for training to holdout: {deciles(nndr_trn_hol)}")
102+
return {
103+
"dcr_syn_trn": dcr_syn_trn,
104+
"nndr_syn_trn": nndr_syn_trn,
105+
"dcr_syn_hol": dcr_syn_hol,
106+
"nndr_syn_hol": nndr_syn_hol,
107+
"dcr_trn_hol": dcr_trn_hol,
108+
"nndr_trn_hol": nndr_trn_hol,
109+
}
91110

92-
return dcr_syn_trn, dcr_syn_hol, dcr_trn_hol
93111

112+
def plot_distances(plot_title: str, distances: dict[str, np.ndarray]) -> go.Figure:
113+
dcr_syn_trn = distances["dcr_syn_trn"]
114+
dcr_syn_hol = distances["dcr_syn_hol"]
115+
dcr_trn_hol = distances["dcr_trn_hol"]
116+
nndr_syn_trn = distances["nndr_syn_trn"]
117+
nndr_syn_hol = distances["nndr_syn_hol"]
118+
nndr_trn_hol = distances["nndr_trn_hol"]
94119

95-
def plot_distances(
96-
plot_title: str, dcr_syn_trn: np.ndarray, dcr_syn_hol: np.ndarray | None, dcr_trn_hol: np.ndarray | None
97-
) -> go.Figure:
98-
# calculate quantiles
120+
# calculate quantiles for DCR
99121
y = np.linspace(0, 1, 101)
100-
x_syn_trn = np.quantile(dcr_syn_trn, y)
122+
123+
# Calculate max values to use later
124+
max_dcr_syn_trn = np.max(dcr_syn_trn)
125+
max_dcr_syn_hol = None if dcr_syn_hol is None else np.max(dcr_syn_hol)
126+
max_dcr_trn_hol = None if dcr_trn_hol is None else np.max(dcr_trn_hol)
127+
max_nndr_syn_trn = np.max(nndr_syn_trn)
128+
max_nndr_syn_hol = None if nndr_syn_hol is None else np.max(nndr_syn_hol)
129+
max_nndr_trn_hol = None if nndr_trn_hol is None else np.max(nndr_trn_hol)
130+
131+
# Ensure first point is always at x=0 for all lines
132+
# and last point is at the maximum x value with y=1
133+
x_dcr_syn_trn = np.concatenate([[0], np.quantile(dcr_syn_trn, y[1:-1]), [max_dcr_syn_trn]])
101134
if dcr_syn_hol is not None:
102-
x_syn_hol = np.quantile(dcr_syn_hol, y)
135+
x_dcr_syn_hol = np.concatenate([[0], np.quantile(dcr_syn_hol, y[1:-1]), [max_dcr_syn_hol]])
103136
else:
104-
x_syn_hol = None
137+
x_dcr_syn_hol = None
105138

106139
if dcr_trn_hol is not None:
107-
x_trn_hol = np.quantile(dcr_trn_hol, y)
140+
x_dcr_trn_hol = np.concatenate([[0], np.quantile(dcr_trn_hol, y[1:-1]), [max_dcr_trn_hol]])
108141
else:
109-
x_trn_hol = None
142+
x_dcr_trn_hol = None
143+
144+
# calculate quantiles for NNDR
145+
x_nndr_syn_trn = np.concatenate([[0], np.quantile(nndr_syn_trn, y[1:-1]), [max_nndr_syn_trn]])
146+
if nndr_syn_hol is not None:
147+
x_nndr_syn_hol = np.concatenate([[0], np.quantile(nndr_syn_hol, y[1:-1]), [max_nndr_syn_hol]])
148+
else:
149+
x_nndr_syn_hol = None
150+
151+
if nndr_trn_hol is not None:
152+
x_nndr_trn_hol = np.concatenate([[0], np.quantile(nndr_trn_hol, y[1:-1]), [max_nndr_trn_hol]])
153+
else:
154+
x_nndr_trn_hol = None
155+
156+
# Adjust y to match the new x arrays with the added 0 and 1 points
157+
y = np.concatenate([[0], y[1:-1], [1]])
110158

111159
# prepare layout
112160
layout = go.Layout(
@@ -120,80 +168,132 @@ def plot_distances(
120168
plot_bgcolor=CHARTS_COLORS["background"],
121169
autosize=True,
122170
height=500,
123-
margin=dict(l=20, r=20, b=20, t=40, pad=5),
171+
margin=dict(l=20, r=20, b=20, t=60, pad=5),
124172
showlegend=True,
125-
yaxis=dict(
126-
showticklabels=False,
127-
zeroline=True,
128-
zerolinewidth=1,
129-
zerolinecolor="#999999",
130-
rangemode="tozero",
173+
)
174+
175+
# Create a figure with two subplots side by side
176+
fig = go.Figure(layout=layout).set_subplots(
177+
rows=1,
178+
cols=2,
179+
horizontal_spacing=0.05,
180+
subplot_titles=("Distance to Closest Record (DCR)", "Nearest Neighbor Distance Ratio (NNDR)"),
181+
)
182+
fig.update_annotations(font_size=12)
183+
184+
# Configure axes for both subplots
185+
for i in range(1, 3):
186+
fig.update_xaxes(
187+
col=i,
131188
showline=True,
132189
linewidth=1,
133190
linecolor="#999999",
134-
),
135-
yaxis2=dict(
136-
overlaying="y",
137-
side="right",
191+
hoverformat=".3f",
192+
)
193+
194+
# Only show y-axis on the right side with percentage labels
195+
fig.update_yaxes(
196+
col=i,
138197
tickformat=".0%",
139198
showgrid=False,
140-
range=[0, 1],
199+
range=[-0.01, 1.01],
141200
showline=True,
142201
linewidth=1,
143202
linecolor="#999999",
203+
side="right",
204+
tickvals=[0, 0.2, 0.4, 0.6, 0.8, 1.0],
205+
)
206+
207+
# Add traces for DCR plot (left subplot)
208+
# training vs holdout (light gray)
209+
if x_dcr_trn_hol is not None:
210+
fig.add_trace(
211+
go.Scatter(
212+
mode="lines",
213+
x=x_dcr_trn_hol,
214+
y=y,
215+
name="Training vs. Holdout Data",
216+
line=dict(color="#999999", width=5),
217+
showlegend=True,
218+
),
219+
row=1,
220+
col=1,
221+
)
222+
223+
# synthetic vs holdout (gray)
224+
if x_dcr_syn_hol is not None:
225+
fig.add_trace(
226+
go.Scatter(
227+
mode="lines",
228+
x=x_dcr_syn_hol,
229+
y=y,
230+
name="Synthetic vs. Holdout Data",
231+
line=dict(color="#666666", width=5),
232+
showlegend=True,
233+
),
234+
row=1,
235+
col=1,
236+
)
237+
238+
# synthetic vs training (green)
239+
fig.add_trace(
240+
go.Scatter(
241+
mode="lines",
242+
x=x_dcr_syn_trn,
243+
y=y,
244+
name="Synthetic vs. Training Data",
245+
line=dict(color="#24db96", width=5),
246+
showlegend=True,
144247
),
145-
xaxis=dict(
146-
showline=True,
147-
linewidth=1,
148-
linecolor="#999999",
149-
hoverformat=".3f",
150-
),
248+
row=1,
249+
col=1,
151250
)
152-
fig = go.Figure(layout=layout)
153-
154-
traces = []
155251

252+
# Add traces for NNDR plot (right subplot)
156253
# training vs holdout (light gray)
157-
if x_trn_hol is not None:
158-
traces.append(
254+
if x_nndr_trn_hol is not None:
255+
fig.add_trace(
159256
go.Scatter(
160257
mode="lines",
161-
x=x_trn_hol,
258+
x=x_nndr_trn_hol,
162259
y=y,
163260
name="Training vs. Holdout Data",
164261
line=dict(color="#999999", width=5),
165-
yaxis="y2",
166-
)
262+
showlegend=False,
263+
),
264+
row=1,
265+
col=2,
167266
)
168267

169268
# synthetic vs holdout (gray)
170-
if x_syn_hol is not None:
171-
traces.append(
269+
if x_nndr_syn_hol is not None:
270+
fig.add_trace(
172271
go.Scatter(
173272
mode="lines",
174-
x=x_syn_hol,
273+
x=x_nndr_syn_hol,
175274
y=y,
176275
name="Synthetic vs. Holdout Data",
177276
line=dict(color="#666666", width=5),
178-
yaxis="y2",
179-
)
277+
showlegend=False,
278+
),
279+
row=1,
280+
col=2,
180281
)
181282

182283
# synthetic vs training (green)
183-
traces.append(
284+
fig.add_trace(
184285
go.Scatter(
185286
mode="lines",
186-
x=x_syn_trn,
287+
x=x_nndr_syn_trn,
187288
y=y,
188289
name="Synthetic vs. Training Data",
189290
line=dict(color="#24db96", width=5),
190-
yaxis="y2",
191-
)
291+
showlegend=False,
292+
),
293+
row=1,
294+
col=2,
192295
)
193296

194-
for trace in traces:
195-
fig.add_trace(trace)
196-
197297
fig.update_layout(
198298
legend=dict(
199299
orientation="h",
@@ -210,12 +310,11 @@ def plot_distances(
210310

211311

212312
def plot_store_distances(
213-
dcr_syn_trn: np.ndarray,
214-
dcr_syn_hol: np.ndarray | None,
215-
dcr_trn_hol: np.ndarray | None,
313+
distances: dict[str, np.ndarray],
216314
workspace: TemporaryWorkspace,
217315
) -> None:
218316
fig = plot_distances(
219-
"Cumulative Distributions of Distance to Closest Records (DCR)", dcr_syn_trn, dcr_syn_hol, dcr_trn_hol
317+
"Cumulative Distributions of Distance Metrics",
318+
distances,
220319
)
221320
workspace.store_figure_html(fig, "distances_dcr")

mostlyai/qa/assets/html/report_template.html

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -411,7 +411,7 @@ <h2 id="distances" class="anchor">Distances</h2>
411411
{% endif %}
412412
</tr>
413413
<tr>
414-
<td>Average Distances</td>
414+
<td>DCR Average</td>
415415
<td>{{ "{:.3f}".format(metrics.distances.dcr_training) }}</td>
416416
{% if metrics.distances.dcr_holdout is not none %}
417417
<td><small style="color: #666666;">{{ "{:.3f}".format(metrics.distances.dcr_holdout) }}</small></td>
@@ -422,16 +422,25 @@ <h2 id="distances" class="anchor">Distances</h2>
422422
<tr>
423423
<td>DCR Share</td>
424424
<td>{{ "{:.1%}".format(metrics.distances.dcr_share) }}</td>
425-
<td></td>
425+
<td><small style="color: #666666;">{{ "{:.1%}".format(1 - metrics.distances.dcr_share) }}</small></td>
426426
<td></td>
427427
</tr>
428428
{% endif %}
429+
<tr>
430+
<td>NNDR Min10</td>
431+
<td>{{ "{:.2e}".format(metrics.distances.nndr_training) if metrics.distances.nndr_training < 0.01 else "{:.3f}".format(metrics.distances.nndr_training) }}</td>
432+
{% if metrics.distances.nndr_holdout is not none %}
433+
<td><small style="color: #666666;">{{ "{:.2e}".format(metrics.distances.nndr_holdout) if metrics.distances.nndr_holdout < 0.01 else "{:.3f}".format(metrics.distances.nndr_holdout) }}</small></td>
434+
<td></td>
435+
{% endif %}
436+
</tr>
429437
</tbody>
430438
</table>
431-
<br />
432-
<div class="white-box p-3">
439+
</div>
440+
</div>
441+
<div class="row">
442+
<div class="white-box p-3">
433443
{{ distances_dcr_html_chart }}
434-
</div>
435444
</div>
436445
</div>
437446
<br />

0 commit comments

Comments
 (0)