@@ -58,8 +58,12 @@ def _compute_dist2_matrix_scaling(params, batch_rows, batch_columns, periods=Non
58
58
distances between all points in 'batch_rows' and all points in 'batch_columns'.
59
59
"""
60
60
diffs = batch_rows [:, jnp .newaxis , :] - batch_columns [jnp .newaxis , :, :]
61
- if periods is not None : # nonperiodic features must have entry '0'
62
- diffs -= jnp .where (periods , 1.0 , 0.0 ) * jnp .round (diffs / periods ) * periods
61
+ if periods is not None :
62
+ periodic_mask = periods > 0 # only shift periodic features
63
+ periodic_shifts = (
64
+ jnp .round (diffs / jnp .where (periodic_mask , periods , 1.0 )) * periods
65
+ )
66
+ diffs -= jnp .where (periodic_mask , periodic_shifts , 0.0 )
63
67
diffs *= params [jnp .newaxis , jnp .newaxis , :]
64
68
dist2_matrix = jnp .sum (diffs * diffs , axis = - 1 )
65
69
return dist2_matrix
@@ -286,10 +290,12 @@ def _compute_rank_matrix(batch_rows, batch_columns, periods):
286
290
to itself (when a point appears both in batch_rows and batch_columns).
287
291
"""
288
292
diffs = batch_rows [:, jnp .newaxis , :] - batch_columns [jnp .newaxis , :, :]
289
- if periods is not None : # nonperiodic features must have entry '0'
290
- diffs -= (
291
- jnp .where (periods , 1.0 , 0.0 ) * jnp .round (diffs / periods ) * periods
293
+ if periods is not None :
294
+ periodic_mask = periods > 0 # only shift periodic features
295
+ periodic_shifts = (
296
+ jnp .round (diffs / jnp .where (periodic_mask , periods , 1.0 )) * periods
292
297
)
298
+ diffs -= jnp .where (periodic_mask , periodic_shifts , 0.0 )
293
299
dist2_matrix = jnp .sum (diffs * diffs , axis = - 1 )
294
300
rank_matrix = dist2_matrix .argsort (axis = 1 ).argsort (axis = 1 )
295
301
return rank_matrix
0 commit comments