Skip to content

Commit 368488d

Browse files
committed
fix bug with mixed periodic/nonperiodic features
1 parent c838604 commit 368488d

File tree

1 file changed

+11
-5
lines changed

1 file changed

+11
-5
lines changed

dadapy/diff_imbalance.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -58,8 +58,12 @@ def _compute_dist2_matrix_scaling(params, batch_rows, batch_columns, periods=Non
5858
distances between all points in 'batch_rows' and all points in 'batch_columns'.
5959
"""
6060
diffs = batch_rows[:, jnp.newaxis, :] - batch_columns[jnp.newaxis, :, :]
61-
if periods is not None: # nonperiodic features must have entry '0'
62-
diffs -= jnp.where(periods, 1.0, 0.0) * jnp.round(diffs / periods) * periods
61+
if periods is not None:
62+
periodic_mask = periods > 0 # only shift periodic features
63+
periodic_shifts = (
64+
jnp.round(diffs / jnp.where(periodic_mask, periods, 1.0)) * periods
65+
)
66+
diffs -= jnp.where(periodic_mask, periodic_shifts, 0.0)
6367
diffs *= params[jnp.newaxis, jnp.newaxis, :]
6468
dist2_matrix = jnp.sum(diffs * diffs, axis=-1)
6569
return dist2_matrix
@@ -286,10 +290,12 @@ def _compute_rank_matrix(batch_rows, batch_columns, periods):
286290
to itself (when a point appears both in batch_rows and batch_columns).
287291
"""
288292
diffs = batch_rows[:, jnp.newaxis, :] - batch_columns[jnp.newaxis, :, :]
289-
if periods is not None: # nonperiodic features must have entry '0'
290-
diffs -= (
291-
jnp.where(periods, 1.0, 0.0) * jnp.round(diffs / periods) * periods
293+
if periods is not None:
294+
periodic_mask = periods > 0 # only shift periodic features
295+
periodic_shifts = (
296+
jnp.round(diffs / jnp.where(periodic_mask, periods, 1.0)) * periods
292297
)
298+
diffs -= jnp.where(periodic_mask, periodic_shifts, 0.0)
293299
dist2_matrix = jnp.sum(diffs * diffs, axis=-1)
294300
rank_matrix = dist2_matrix.argsort(axis=1).argsort(axis=1)
295301
return rank_matrix

0 commit comments

Comments
 (0)