
Commit 62700a9

remove topK as it is the same as k_max in practice
1 parent 2c32bd6

3 files changed

Lines changed: 23 additions & 23 deletions


src/__config__.yaml

Lines changed: 0 additions & 1 deletion
@@ -25,7 +25,6 @@ fair: #post-hoc reranks of the prediction list

eval:
  per_instance: True # needed for paired significance tests
-  topK: ${fair.k_max} # first stage retrieval to enhance efficiency
  metrics:
    fair: [ndkl, skew] #, exp, expu >> python 3.9+ from FairRankTune
    topk: '1,2,${fair.k_max}'
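
Why dropping eval.topK is safe here: eval.metrics.topk already interpolates ${fair.k_max}, so one value drives the cutoff everywhere. A minimal sketch, assuming the ${...} interpolation is resolved by OmegaConf (which the syntax suggests); the values are illustrative only:

from omegaconf import OmegaConf

# illustrative config; only the interpolation pattern mirrors src/__config__.yaml
cfg = OmegaConf.create("""
fair:
  k_max: 100
eval:
  per_instance: True
  metrics:
    fair: [ndkl, skew]
    topk: '1,2,${fair.k_max}'
""")

print(cfg.eval.metrics.topk)  # '1,2,100' -- the removed eval.topK duplicated this value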

src/adila.py

Lines changed: 21 additions & 20 deletions
@@ -27,16 +27,17 @@ def __str__(self): return f'{self.attribute}.{self.fair_notion}.{self.is_popular
    def _get_labeled_sorted_preds(self, preds, minorities, k_max):
        if not preds.is_sparse: sorted_probs, sorted_indices = preds.sort(dim=1, descending=True) # |Test| * |Experts|
        else: #|Test| * |topK == k_max|, we need to avoid working with dense
+            preds = preds.coalesce()
            rows, cols = preds.indices()
            vals = preds.values()
            order = torch.argsort(rows * (vals.max() + 1) - vals) # row-wise descending sort
            rows, cols, vals = rows[order], cols[order], vals[order]
            # print(torch.bincount(rows))
            splits = torch.split(cols, torch.bincount(rows).tolist())
            probs_splits = torch.split(vals, torch.bincount(rows).tolist())
-            # pad each row to topK==k_max (or preds.size(1) but that would be dense again)
-            sorted_indices = torch.stack([torch.cat([x, torch.tensor([i for i in range(k_max) if i not in x])]) for x in splits]) # pad col idx of zero values
-            sorted_probs = torch.stack([torch.cat([x, x.new_zeros(k_max - len(x))]) for x in probs_splits]) # pad zero values
+            # pad each row to k_max (or preds.size(1) but that would be dense again)
+            sorted_indices = torch.stack([torch.cat([x, torch.tensor([i for i in range(k_max) if i not in x])]) if len(x) < k_max else x[:k_max] for x in splits]) # pad col idx of zero values
+            sorted_probs = torch.stack([torch.cat([x, x.new_zeros(k_max - len(x))]) if len(x) < k_max else x[:k_max] for x in probs_splits]) # pad zero values

        sorted_labels = (sorted_indices[..., None] == torch.tensor(minorities)).any(dim=-1)
        ## if |experts| are small/mid scale >> dense vector of boolean labels
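
To make the sparse branch above easier to follow, here is a self-contained sketch of the same row-wise descending sort and k_max padding on a toy tensor. The data is made up; the sketch assumes non-negative probabilities, which is what keeps the composite sort key from interleaving rows:

import torch

k_max = 4
# toy |Test| x |Experts| predictions, kept sparse to avoid a dense |Test| x |Experts| sort
preds = torch.tensor([[0.0, 0.9, 0.0, 0.2],
                      [0.7, 0.0, 0.0, 0.0]]).to_sparse().coalesce()

rows, cols = preds.indices()
vals = preds.values()

# one composite key per nonzero: ascending in row, descending in value within a row
order = torch.argsort(rows * (vals.max() + 1) - vals)
rows, cols, vals = rows[order], cols[order], vals[order]

# split the flat column/value streams back into per-row chunks
counts = torch.bincount(rows).tolist()
splits, probs_splits = torch.split(cols, counts), torch.split(vals, counts)

# pad each row to k_max so stacking yields |Test| x k_max, never |Test| x |Experts|
sorted_indices = torch.stack([torch.cat([x, torch.tensor([i for i in range(k_max) if i not in x])]) if len(x) < k_max else x[:k_max] for x in splits])
sorted_probs = torch.stack([torch.cat([x, x.new_zeros(k_max - len(x))]) if len(x) < k_max else x[:k_max] for x in probs_splits])

print(sorted_indices)  # [[1, 3, 0, 2], [0, 1, 2, 3]]
print(sorted_probs)    # [[0.9, 0.2, 0.0, 0.0], [0.7, 0.0, 0.0, 0.0]]

# boolean minority labels per ranked slot, as in the hunk's last context line
minorities = [0, 3]
print((sorted_indices[..., None] == torch.tensor(minorities)).any(dim=-1))
# [[False, True, True, False], [True, False, False, False]]
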
@@ -176,28 +177,28 @@ def rerank(self, fpred, minorities, ratios, algorithm='det_greedy', k_max=100, a
            with open(fpred_, 'wb') as f: pickle.dump(preds_, f)
        return preds, preds_, fpred_

-    def eval_fair(self, preds, minorities, preds_, fpred_, ratios, topK, metrics=['skew', 'ndkl'], per_instance=False):
+    def eval_fair(self, preds, minorities, preds_, fpred_, ratios, k_max, metrics=['skew', 'ndkl'], per_instance=False):
        """
        Args:
            preds: loaded predictions from a .pred file
            minorities: list of popular or female labels (true labels)
            preds_, fpred_: re-ranked probs considering a cut-off min(k_max, preds.shape[1]) and the stored filename
            ratios: inferred or a desired ratio of minorities
-            topK: cutoff for fair reranking methods, ideally should be equal to k_max in reranking
+            k_max: cutoff for fair reranking methods
            metrics: fairness evaluation metrics
            per_instance: evaluation metric value for each test team instance
        Returns:
            None but the results are stored in *.csv files
        """
-        log.info(f'{opentf.textcolor["green"]}Fairness evaluation for {fpred_} using {metrics} with {topK} cutoff ...{opentf.textcolor["reset"]}')
+        log.info(f'{opentf.textcolor["green"]}Fairness evaluation for {fpred_} using {metrics} with {k_max} cutoff ...{opentf.textcolor["reset"]}')
        frr = opentf.install_import('reranking') # for ndkl and skew
-        teams = self._get_labeled_sorted_preds(preds, minorities) # [5, 1, 3, 6, 2, 0, 4] -> [0.8, 0.5, 0.4, 0.3, 0.3, 0.1, 0.1]
-        teams_ = self._get_labeled_sorted_preds(preds_, minorities) # [2, 0, 1, 5, 4, 3, 6] -> [0.8, 0.5, 0.4, 0.3, 0.3, 0.1, 0.1]
+        teams = self._get_labeled_sorted_preds(preds, minorities, k_max) # [5, 1, 3, 6, 2, 0, 4] -> [0.8, 0.5, 0.4, 0.3, 0.3, 0.1, 0.1]
+        teams_ = self._get_labeled_sorted_preds(preds_, minorities, k_max) # [2, 0, 1, 5, 4, 3, 6] -> [0.8, 0.5, 0.4, 0.3, 0.3, 0.1, 0.1]

        results = []
-        topK = min(topK, preds.shape[1])
+        k_max = min(k_max, preds.shape[1])
        for i, (team, team_) in enumerate(tqdm(zip(teams, teams_))):
-            lsteam, lsteam_ = team[:, 1][:topK].bool().tolist(), team_[:, 1][:topK].bool().tolist()
+            lsteam, lsteam_ = team[:, 1][:k_max].bool().tolist(), team_[:, 1][:k_max].bool().tolist()
            if self.fair_notion == 'eo': r = min(max(ratios[i], 0.1), 0.9) # dynamic ratio r, clamps to stay between [0.1,0.9]
            else: r = ratios[0]

@@ -208,10 +209,10 @@ def eval_fair(self, preds, minorities, preds_, fpred_, ratios, topK, metrics=['s
                    result[f'after.{metric}'] = frr.ndkl(lsteam_, {True: r, False: 1 - r})

                if 'skew' == metric:
-                    result[f'before.{metric}.minority'] = frr.skew(lsteam.count(True)/topK, r)
-                    result[f'before.{metric}.majority'] = frr.skew(lsteam.count(False)/topK, 1 - r)
-                    result[f'after.{metric}.minority'] = frr.skew(lsteam_.count(True)/topK, r)
-                    result[f'after.{metric}.majority'] = frr.skew(lsteam_.count(False)/topK, 1 - r)
+                    result[f'before.{metric}.minority'] = frr.skew(lsteam.count(True)/k_max, r)
+                    result[f'before.{metric}.majority'] = frr.skew(lsteam.count(False)/k_max, 1 - r)
+                    result[f'after.{metric}.minority'] = frr.skew(lsteam_.count(True)/k_max, r)
+                    result[f'after.{metric}.majority'] = frr.skew(lsteam_.count(False)/k_max, 1 - r)

                # if metric in ['exp', 'expu']:
                #     frt = opentf.install_import('FairRankTune') #python 3.9+

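
For reference, the fairness calls in the two hunks above reduce to ndkl and skew from the pip package reranking over a boolean minority-membership list cut at k_max. A hedged, stand-alone sketch with toy values, importing the package directly rather than via opentf.install_import:

import reranking as frr  # pip install reranking

k_max = 5
r = 0.4                                     # desired (or inferred) minority ratio
lsteam = [True, False, False, True, False]  # minority membership of the top-k_max ranked experts

# NDKL of the ranked list against the desired {minority: r, majority: 1 - r} distribution
print(frr.ndkl(lsteam, {True: r, False: 1 - r}))

# skew compares each group's observed share in the top-k_max against its desired share
print(frr.skew(lsteam.count(True) / k_max, r))       # minority
print(frr.skew(lsteam.count(False) / k_max, 1 - r))  # majority
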
@@ -246,19 +247,19 @@ def eval_fair(self, preds, minorities, preds_, fpred_, ratios, topK, metrics=['s
        df_mean.rename(columns={'before': 'mean.before', 'after': 'mean.after'}).to_csv(f'{fpred_}.eval.fair.mean.csv', index=False)
        log.info(f'Saved at {fpred_}.eval.fair.mean{"/instance" if per_instance else ""}.csv.')

-    def eval_utility(self, preds, fpred, preds_, fpred_, topK, metrics, per_instance=False) -> None:
+    def eval_utility(self, preds, fpred, preds_, fpred_, k_max, metrics, per_instance=False) -> None:
        """
        Args:
            preds: the file for the predictions, *.pred file
            preds_: the file for the re-ranked probs considering a cut-off min(k_max, preds.shape[1]) and the stored filename
-            topK: first stage retrieval for efficiency
+            k_max: cutoff for fair reranking methods
            metrics: utility evaluation metrics
            per_instance: evaluation metric value for each test team instance
        Returns:
            None but the results are stored in *.csv files
        """

-        def _evaluate(Y_, metrics, per_instance, topK):
+        def _evaluate(Y_, metrics, per_instance, k_max):
            df, df_mean = pd.DataFrame(), pd.DataFrame()
            if not (metrics.trec or metrics.other): df, df_mean
            # evl = opentf.install_import('evl.metric', 'metric_')

@@ -268,7 +269,7 @@ def _evaluate(Y_, metrics, per_instance, topK):
            # from https://github.com/fani-lab/OpeNTF/blob/main/src/mdl/ntf.py#L59
            if metrics.trec:
                log.info(f'{metrics.trec} ...')
-                df, df_mean = evl.calculate_metrics(Y, Y_, topK, per_instance, metrics.trec)
+                df, df_mean = evl.calculate_metrics(Y, Y_, k_max, per_instance, metrics.trec)
            if (m := [m for m in metrics.other if 'aucroc' in m]):
                log.info(f'{m} ...')
                aucroc, _ = evl.calculate_auc_roc(Y, Y_)

@@ -296,7 +297,7 @@ def _evaluate(Y_, metrics, per_instance, topK):
            if per_instance: df_before = pd.read_csv(f'{fpred}.eval.instance.csv', header=0)
        except FileNotFoundError:
            log.info(f'Before: Loading {fpred}.eval.mean.csv failed! Evaluating from scratch ...')
-            df_before, df_before_mean = _evaluate(preds, metrics, per_instance, topK)
+            df_before, df_before_mean = _evaluate(preds, metrics, per_instance, k_max)
        if per_instance: df_before.to_csv(f'{fpred}.eval.instance.csv', float_format='%.5f', index=False)
        log.info(f'Before: Saving {fpred}.eval.mean.csv ...')
        df_before_mean.to_csv(f'{fpred}.eval.mean.csv')

@@ -305,7 +306,7 @@ def _evaluate(Y_, metrics, per_instance, topK):
        df_before_mean.rename(columns={'mean': 'mean.before'}, inplace=True)

        log.info(f'After: Evaluating {fpred_} ...')
-        df_after, df_after_mean = _evaluate(preds_, metrics, per_instance, topK)
+        df_after, df_after_mean = _evaluate(preds_, metrics, per_instance, k_max)
        if per_instance: df_after.rename(columns={c: f'{c}.after' for c in df_after.columns}, inplace=True)
        df_after_mean.rename(columns={'mean': 'mean.after'}, inplace=True)
        if per_instance: pd.concat([df_before.reset_index(drop=True), df_after.reset_index(drop=True)], axis=1).to_csv(f'{fpred_}.eval.utility.instance.csv', float_format='%.5f', index=False)
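
The tail of this file's diff writes per-instance utility scores before and after reranking side by side. A tiny pandas sketch of the resulting *.eval.utility.instance.csv layout, with made-up metric names and values:

import pandas as pd

# placeholder per-instance scores; column names are illustrative, not the project's metric set
df_before = pd.DataFrame({'ndcg_cut_10': [0.31, 0.27], 'map_cut_10': [0.12, 0.09]})
df_after = pd.DataFrame({'ndcg_cut_10': [0.29, 0.26], 'map_cut_10': [0.11, 0.09]})

# same pattern as the hunk above: suffix the after-columns, then concatenate column-wise
df_after.rename(columns={c: f'{c}.after' for c in df_after.columns}, inplace=True)
pd.concat([df_before.reset_index(drop=True), df_after.reset_index(drop=True)], axis=1).to_csv('toy.eval.utility.instance.csv', float_format='%.5f', index=False)
# columns: ndcg_cut_10, map_cut_10, ndcg_cut_10.after, map_cut_10.after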

src/main.py

Lines changed: 2 additions & 2 deletions
@@ -6,8 +6,8 @@ def init_process(): logging.basicConfig(level=logging.INFO)

def __(fpred, adila, minorities, ratios, algorithm, k_max, alpha, evalcfg):
    preds, preds_, fpred_ = adila.rerank(fpred, minorities, ratios, algorithm, k_max, alpha)
-    adila.eval_fair(preds, minorities, preds_, fpred_, ratios, evalcfg.topK, evalcfg.metrics.fair, evalcfg.per_instance)
-    adila.eval_utility(preds, fpred, preds_, fpred_, evalcfg.topK, evalcfg.metrics, evalcfg.per_instance)
+    adila.eval_fair(preds, minorities, preds_, fpred_, ratios, k_max, evalcfg.metrics.fair, evalcfg.per_instance)
+    adila.eval_utility(preds, fpred, preds_, fpred_, k_max, evalcfg.metrics, evalcfg.per_instance)

def _(adila, fpred, minorities, ratios, algorithm, k_max, alpha, acceleration, evalcfg):
    if os.path.isfile(fpred): __(fpred, adila, minorities, ratios, algorithm, k_max, alpha, evalcfg)
