Skip to content

Commit 7b7fa49

Browse files
feat: add trivariate accuracy metric (#179)
Co-authored-by: Lukasz Kolodziejczyk <[email protected]>
1 parent d439ad6 commit 7b7fa49

9 files changed

+178
-38
lines changed

mostlyai/qa/_accuracy.py

+58-6
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import hashlib
1717
import logging
1818
import math
19+
import time
1920
from typing import Any, Literal
2021
from collections.abc import Callable, Iterable
2122

@@ -30,6 +31,7 @@
3031
CHARTS_COLORS,
3132
CHARTS_FONTS,
3233
EMPTY_BIN,
34+
MAX_TRIVARIATES,
3335
NA_BIN,
3436
MIN_RARE_CAT_PROTECTION,
3537
OTHER_BIN,
@@ -58,9 +60,9 @@ def calculate_univariates(
5860
"""
5961
Calculates univariate accuracies for all target columns.
6062
"""
61-
_LOG.info("calculate univariates")
62-
63+
t0 = time.time()
6364
tgt_cols = [c for c in ori_bin.columns if c.startswith(TGT_COLUMN)]
65+
6466
accuracies = pd.DataFrame({"column": tgt_cols})
6567
with parallel_config("loky", n_jobs=min(16, max(1, cpu_count() - 1))):
6668
results = Parallel()(
@@ -71,6 +73,9 @@ def calculate_univariates(
7173
for _, row in accuracies.iterrows()
7274
)
7375
accuracies["accuracy"], accuracies["accuracy_max"] = zip(*results)
76+
77+
_LOG.info(f"calculated univariates for {len(tgt_cols)} columns in {time.time() - t0:.2f} seconds")
78+
7479
return accuracies
7580

7681

@@ -87,7 +92,7 @@ def calculate_bivariates(
8792
For each such column pair, value pair frequencies
8893
are calculated both for training and synthetic data.
8994
"""
90-
_LOG.info("calculate bivariates")
95+
t0 = time.time()
9196

9297
# the result for symmetric pairs is the same, so we only calculate one of them
9398
# later, we append copy results for symmetric pairs
@@ -107,7 +112,6 @@ def calculate_bivariates(
107112
else:
108113
# enforce consistent columns
109114
accuracies[["accuracy", "accuracy_max"]] = None
110-
# ensure required number of progress messages are sent
111115

112116
accuracies = pd.concat(
113117
[
@@ -117,6 +121,8 @@ def calculate_bivariates(
117121
axis=0,
118122
).reset_index(drop=True)
119123

124+
_LOG.info(f"calculated bivariate accuracies for {len(accuracies)} combinations in {time.time() - t0:.2f} seconds")
125+
120126
return accuracies
121127

122128

@@ -169,6 +175,49 @@ def calculate_bivariate_columns(ori_bin: pd.DataFrame, append_symetric: bool = T
169175
return columns_df
170176

171177

178+
def calculate_trivariates(ori_bin: pd.DataFrame, syn_bin: pd.DataFrame) -> pd.DataFrame:
179+
"""
180+
Calculates trivariate accuracies.
181+
"""
182+
t0 = time.time()
183+
184+
accuracies = calculate_trivariate_columns(ori_bin)
185+
186+
# calculate trivariates if there is at least one pair
187+
if len(accuracies) > 0:
188+
with parallel_config("loky", n_jobs=min(16, max(1, cpu_count() - 1))):
189+
results = Parallel()(
190+
delayed(calculate_accuracy)(
191+
ori_bin_cols=ori_bin[[row["col1"], row["col2"], row["col3"]]],
192+
syn_bin_cols=syn_bin[[row["col1"], row["col2"], row["col3"]]],
193+
)
194+
for _, row in accuracies.iterrows()
195+
)
196+
accuracies["accuracy"], accuracies["accuracy_max"] = zip(*results)
197+
else:
198+
# enforce consistent columns
199+
accuracies[["accuracy", "accuracy_max"]] = None
200+
201+
_LOG.info(f"calculated trivariate accuracies for {len(accuracies)} combinations in {time.time() - t0:.2f} seconds")
202+
203+
return accuracies
204+
205+
206+
def calculate_trivariate_columns(ori_bin: pd.DataFrame) -> pd.DataFrame:
207+
"""
208+
Creates DataFrame with all column-triples subject to trivariate analysis.
209+
"""
210+
tgt_cols = [c for c in ori_bin.columns if c.startswith(TGT_COLUMN_PREFIX)]
211+
columns_df = pd.DataFrame({"col1": tgt_cols})
212+
columns_df = pd.merge(columns_df, pd.DataFrame({"col2": tgt_cols}), how="cross")
213+
columns_df = pd.merge(columns_df, pd.DataFrame({"col3": tgt_cols}), how="cross")
214+
columns_df = columns_df.loc[columns_df.col1 < columns_df.col2]
215+
columns_df = columns_df.loc[columns_df.col1 < columns_df.col3]
216+
columns_df = columns_df.loc[columns_df.col2 < columns_df.col3]
217+
columns_df = columns_df.sample(frac=1).head(n=MAX_TRIVARIATES)
218+
return columns_df
219+
220+
172221
def calculate_expected_l1_multinomial(probs: list[float], n_1: int, n_2: int) -> np.float64:
173222
"""
174223
Calculate expected L1 distance for two multinomial samples of size `n_1` and `n_2` that follow `probs`.
@@ -349,7 +398,7 @@ def calculate_bin_counts(
349398
"""
350399
Calculates counts of unique values in each bin.
351400
"""
352-
_LOG.info("calculate bin counts")
401+
t0 = time.time()
353402
with parallel_config("loky", n_jobs=min(16, max(1, cpu_count() - 1))):
354403
results = Parallel()(
355404
delayed(bin_count_uni)(
@@ -359,8 +408,10 @@ def calculate_bin_counts(
359408
for col, values in binned.items()
360409
)
361410
bin_cnts_uni = dict(results)
411+
_LOG.info(f"calculated univariate bin counts for {len(binned.columns)} columns in {time.time() - t0:.2f} seconds")
362412

363-
biv_cols = calculate_bivariate_columns(binned)
413+
t0 = time.time()
414+
biv_cols = calculate_bivariate_columns(binned, append_symetric=True)
364415
with parallel_config("loky", n_jobs=min(16, max(1, cpu_count() - 1))):
365416
results = Parallel()(
366417
delayed(bin_count_biv)(
@@ -372,6 +423,7 @@ def calculate_bin_counts(
372423
for _, row in biv_cols.iterrows()
373424
)
374425
bin_cnts_biv = dict(results)
426+
_LOG.info(f"calculated bivariate bin counts for {len(biv_cols)} combinations in {time.time() - t0:.2f} seconds")
375427

376428
return bin_cnts_uni, bin_cnts_biv
377429

mostlyai/qa/_common.py

+1
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
MAX_BIVARIATE_TGT_PLOTS = 300
2929
MAX_BIVARIATE_CTX_PLOTS = 60
3030
MAX_BIVARIATE_NXT_PLOTS = 60
31+
MAX_TRIVARIATES = 10_000
3132

3233
NA_BIN = "(n/a)"
3334
OTHER_BIN = "(other)"

mostlyai/qa/_filesystem.py

+10
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,7 @@ def __init__(self, path: str | Path):
110110
self.bins_dir = self.path / "bins"
111111
self.univariate_accuracies_path = self.path / "univariate_accuracies.parquet"
112112
self.bivariate_accuracies_path = self.path / "bivariate_accuracies.parquet"
113+
self.trivariate_accuracies_path = self.path / "trivariate_accuracies.parquet"
113114
self.numeric_kdes_uni_dir = self.path / "numeric_kdes_uni"
114115
self.categorical_counts_uni_dir = self.path / "categorical_counts_uni"
115116
self.bin_counts_uni_path = self.path / "bin_counts_uni.parquet"
@@ -203,6 +204,15 @@ def load_bivariate_accuracies(self) -> pd.DataFrame:
203204
df["col2"] = df["col2"].str.replace(_OLD_COL_PREFIX, _NEW_COL_PREFIX, regex=True)
204205
return df
205206

207+
def store_trivariate_accuracies(self, trivariates: pd.DataFrame) -> None:
208+
trivariates.to_parquet(self.trivariate_accuracies_path)
209+
210+
def load_trivariate_accuracies(self) -> pd.DataFrame:
211+
if not self.trivariate_accuracies_path.exists():
212+
return pd.DataFrame(columns=["col1", "col2", "col3", "accuracy", "accuracy_max"])
213+
df = pd.read_parquet(self.trivariate_accuracies_path)
214+
return df
215+
206216
def store_numeric_uni_kdes(self, trn_kdes: dict[str, pd.Series]) -> None:
207217
trn_kdes = pd.DataFrame(
208218
[(column, list(xy.index), list(xy.values)) for column, xy in trn_kdes.items()],

mostlyai/qa/_html_report.py

+32-8
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@ def store_report(
7373
meta: dict,
7474
acc_uni: pd.DataFrame,
7575
acc_biv: pd.DataFrame,
76+
acc_triv: pd.DataFrame,
7677
acc_cats_per_seq: pd.DataFrame,
7778
acc_seqs_per_cat: pd.DataFrame,
7879
corr_trn: pd.DataFrame,
@@ -82,7 +83,9 @@ def store_report(
8283
"""
8384

8485
# summarize accuracies by column for overview table
85-
accuracy_table_by_column = summarize_accuracies_by_column(acc_uni, acc_biv, acc_cats_per_seq, acc_seqs_per_cat)
86+
accuracy_table_by_column = summarize_accuracies_by_column(
87+
acc_uni, acc_biv, acc_triv, acc_cats_per_seq, acc_seqs_per_cat
88+
)
8689
accuracy_table_by_column = accuracy_table_by_column.sort_values("univariate", ascending=False)
8790

8891
acc_uni = filter_uni_acc_for_plotting(acc_uni)
@@ -131,27 +134,48 @@ def store_report(
131134

132135

133136
def summarize_accuracies_by_column(
134-
acc_uni: pd.DataFrame, acc_biv: pd.DataFrame, acc_cats_per_seq: pd.DataFrame, acc_seqs_per_cat: pd.DataFrame
137+
acc_uni: pd.DataFrame,
138+
acc_biv: pd.DataFrame,
139+
acc_triv: pd.DataFrame,
140+
acc_cats_per_seq: pd.DataFrame,
141+
acc_seqs_per_cat: pd.DataFrame,
135142
) -> pd.DataFrame:
136143
"""
137-
Calculates DataFrame that stores per-column univariate, bivariate and coherence accuracies.
144+
Calculates DataFrame that stores per-column univariate, bivariate, trivariate and coherence accuracies.
138145
"""
139146

140147
tbl_acc_uni = acc_uni.rename(columns={"accuracy": "univariate", "accuracy_max": "univariate_max"})
148+
tbl_acc = tbl_acc_uni
149+
141150
tbl_acc_biv = (
142-
acc_biv.loc[acc_biv.type != "nxt"]
143-
.groupby("col1")
144-
.mean(["accuracy", "accuracy_max"])
151+
acc_biv.melt(value_vars=["col1", "col2"], value_name="column", id_vars=["accuracy", "accuracy_max"])
152+
.groupby("column")[["accuracy", "accuracy_max"]]
153+
.mean()
145154
.reset_index()
146155
.rename(
147156
columns={
148-
"col1": "column",
149157
"accuracy": "bivariate",
150158
"accuracy_max": "bivariate_max",
151159
}
152160
)
153161
)
154-
tbl_acc = tbl_acc_uni.merge(tbl_acc_biv, how="left")
162+
if not tbl_acc_biv.empty:
163+
tbl_acc = tbl_acc_uni.merge(tbl_acc_biv, how="left")
164+
165+
tbl_acc_triv = (
166+
acc_triv.melt(value_vars=["col1", "col2", "col3"], value_name="column", id_vars=["accuracy", "accuracy_max"])
167+
.groupby("column")[["accuracy", "accuracy_max"]]
168+
.mean()
169+
.reset_index()
170+
.rename(
171+
columns={
172+
"accuracy": "trivariate",
173+
"accuracy_max": "trivariate_max",
174+
}
175+
)
176+
)
177+
if not tbl_acc_triv.empty:
178+
tbl_acc = tbl_acc.merge(tbl_acc_triv, how="left")
155179

156180
acc_nxt = acc_biv.loc[acc_biv.type == "nxt"]
157181
if not all((acc_nxt.empty, acc_cats_per_seq.empty, acc_seqs_per_cat.empty)):

mostlyai/qa/assets/html/report_template.html

+23-10
Original file line numberDiff line numberDiff line change
@@ -75,24 +75,28 @@ <h1 id="summary"><span>{{ meta.report_title }}</span>{{ meta.report_subtitle }}<
7575
<table class='table'>
7676
<tr><td>Univariate</td>
7777
<td align="right">
78-
{{ "{:.1%}".format(metrics.accuracy.univariate) }}<br />
79-
<small class="muted-text">({{ "{:.1%}".format(metrics.accuracy.univariate_max) }})</small>
78+
{{ "{:.1%}".format(metrics.accuracy.univariate) }}
8079
</td>
8180
</tr>
8281
{% if 'bivariate' in accuracy_table_by_column %}
8382
<tr><td>Bivariate</td>
8483
<td align="right">
85-
{{ "{:.1%}".format(metrics.accuracy.bivariate) }}<br />
86-
<small class="muted-text">({{ "{:.1%}".format(metrics.accuracy.bivariate_max) }})</small>
84+
{{ "{:.1%}".format(metrics.accuracy.bivariate) }}
85+
</td>
86+
</tr>
87+
{% endif %}
88+
{% if 'trivariate' in accuracy_table_by_column %}
89+
<tr><td>Trivariate</td>
90+
<td align="right">
91+
{{ "{:.1%}".format(metrics.accuracy.trivariate) }}
8792
</td>
8893
</tr>
8994
{% endif %}
9095
{% if 'coherence' in accuracy_table_by_column %}
9196
<tr>
9297
<td>Coherence</td>
9398
<td align="right">
94-
{{ "{:.1%}".format(metrics.accuracy.coherence).replace('nan%', '-') }}<br />
95-
<small class="muted-text">({{ "{:.1%}".format(metrics.accuracy.coherence_max).replace('nan%', '-') }})</small>
99+
{{ "{:.1%}".format(metrics.accuracy.coherence).replace('nan%', '-') }}
96100
</td>
97101
</tr>
98102
{% endif %}
@@ -305,6 +309,9 @@ <h2 id="accuracy" class="anchor">Accuracy</h2>
305309
{% if 'bivariate' in accuracy_table_by_column %}
306310
<th>Bivariate</th>
307311
{% endif %}
312+
{% if 'trivariate' in accuracy_table_by_column %}
313+
<th>Trivariate</th>
314+
{% endif %}
308315
{% if 'coherence' in accuracy_table_by_column %}
309316
<th>Coherence</th>
310317
{% endif %}
@@ -318,6 +325,9 @@ <h2 id="accuracy" class="anchor">Accuracy</h2>
318325
{% if 'bivariate' in accuracy_table_by_column %}
319326
<td>{{ "{:.1%}".format(row['bivariate']) }}</td>
320327
{% endif %}
328+
{% if 'trivariate' in accuracy_table_by_column %}
329+
<td>{{ "{:.1%}".format(row['trivariate']) }}</td>
330+
{% endif %}
321331
{% if 'coherence' in accuracy_table_by_column %}
322332
<td>{{ "{:.1%}".format(row['coherence']).replace('nan%', '-') }}</td>
323333
{% endif %}
@@ -327,12 +337,15 @@ <h2 id="accuracy" class="anchor">Accuracy</h2>
327337
<thead>
328338
<tr>
329339
<th>Total</th>
330-
<th>{{ "{:.1%}".format(metrics.accuracy.univariate) }}</th>
340+
<th>{{ "{:.1%}".format(metrics.accuracy.univariate) }} <small class="muted-text">({{ "{:.1%}".format(metrics.accuracy.univariate_max) }})</small></th>
331341
{% if 'bivariate' in accuracy_table_by_column %}
332-
<th>{{ "{:.1%}".format(metrics.accuracy.bivariate) }}</th>
342+
<th>{{ "{:.1%}".format(metrics.accuracy.bivariate) }} <small class="muted-text">({{ "{:.1%}".format(metrics.accuracy.bivariate_max) }})</small></th>
343+
{% endif %}
344+
{% if 'trivariate' in accuracy_table_by_column %}
345+
<th>{{ "{:.1%}".format(metrics.accuracy.trivariate) }} <small class="muted-text">({{ "{:.1%}".format(metrics.accuracy.trivariate_max) }})</small></th>
333346
{% endif %}
334347
{% if 'coherence' in accuracy_table_by_column %}
335-
<th>{{ "{:.1%}".format(metrics.accuracy.coherence) }}</th>
348+
<th>{{ "{:.1%}".format(metrics.accuracy.coherence) }} <small class="muted-text">({{ "{:.1%}".format(metrics.accuracy.coherence_max) }})</small></th>
336349
{% endif %}
337350
</tr>
338351
</thead>
@@ -355,7 +368,7 @@ <h2 id="accuracy" class="anchor">Accuracy</h2>
355368
<div class="explainer-body">
356369
Accuracy of synthetic data is assessed by comparing the distributions of the synthetic (shown in green) and the original data (shown in gray).
357370
For each distribution plot we sum up the deviations across all categories, to get the so-called total variation distance (TVD). The reported accuracy is then simply reported as 100% - TVD.
358-
These accuracies are calculated for all univariate and bivariate distributions. A final accuracy score is then calculated as the average across all of these.
371+
These accuracies are calculated for all univariate, bivariate and trivariate distributions. A final accuracy score is then calculated as the average across all of these.
359372
</div>
360373
</div>
361374
</div>

mostlyai/qa/metrics.py

+15-1
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,8 @@ class Accuracy(CustomBaseModel):
2727
1. **Univariate Accuracy**: The accuracy of the univariate distributions for all target columns.
2828
2. **Bivariate Accuracy**: The accuracy of all pair-wise distributions for target columns, as well as for target
2929
columns with respect to the context columns.
30-
3. **Coherence Accuracy**: The accuracy of the auto-correlation for all target columns.
30+
3. **Trivariate Accuracy**: The accuracy of all three-way distributions for target columns.
31+
4. **Coherence Accuracy**: The accuracy of the auto-correlation for all target columns.
3132
3233
Accuracy is defined as 100% - [Total Variation Distance](https://en.wikipedia.org/wiki/Total_variation_distance_of_probability_measures) (TVD),
3334
whereas TVD is half the sum of the absolute differences of the relative frequencies of the corresponding
@@ -60,6 +61,12 @@ class Accuracy(CustomBaseModel):
6061
ge=0.0,
6162
le=1.0,
6263
)
64+
trivariate: float | None = Field(
65+
default=None,
66+
description="Average accuracy of discretized trivariate distributions.",
67+
ge=0.0,
68+
le=1.0,
69+
)
6370
coherence: float | None = Field(
6471
default=None,
6572
description="Average accuracy of discretized coherence distributions. Only applicable for sequential data.",
@@ -87,6 +94,13 @@ class Accuracy(CustomBaseModel):
8794
ge=0.0,
8895
le=1.0,
8996
)
97+
trivariate_max: float | None = Field(
98+
default=None,
99+
alias="trivariateMax",
100+
description="Expected trivariate accuracy of a same-sized holdout. Serves as a reference for `trivariate`.",
101+
ge=0.0,
102+
le=1.0,
103+
)
90104
coherence_max: float | None = Field(
91105
default=None,
92106
alias="coherenceMax",

0 commit comments

Comments
 (0)