Skip to content

Commit 2dea196

Browse files
vizier-teamcopybara-github
authored andcommitted
Internal
PiperOrigin-RevId: 744067296
1 parent e76d178 commit 2dea196

File tree

1 file changed

+11
-3
lines changed
  • vizier/_src/algorithms/designers/gp

1 file changed

+11
-3
lines changed

vizier/_src/algorithms/designers/gp/yjt.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -79,16 +79,24 @@ def optimal_transformation(
7979
lambdas = preprocessing.PowerTransformer(
8080
method, standardize=False).fit(data).lambdas_.astype(dtype)
8181

82-
logging.info('Optimal lambda was: %s', lambdas)
82+
logging.info('Optimal lambda was: %s, %s', lambdas, lambdas.dtype)
8383

8484
if dimension == 1:
8585
# Make it a scalar, so we don't end up with batch_shape = [1] in the
8686
# bijector.
8787
lambdas = lambdas.item()
8888
if method == 'yeo-johnson':
89-
warp = tfsb.YeoJohnson(lambdas)
89+
# Cast the default values of `rho` and `shift` to the same dtype as `data`
90+
# to avoid dtype mismatch errors.
91+
warp = tfsb.YeoJohnson(
92+
lambdas, rho=np.asarray(2.0, dtype=dtype), shift=np.asarray(1.0, dtype)
93+
)
9094
elif method == 'box-cox':
91-
warp = tfsb.YeoJohnson(lambdas, shift=.0)
95+
# Cast the default values of `rho` and `shift` to the same dtype as `data`
96+
# to avoid dtype mismatch errors.
97+
warp = tfsb.YeoJohnson(
98+
lambdas, rho=np.asarray(2.0, dtype), shift=np.asarray(0.0, dtype)
99+
)
92100
else:
93101
raise ValueError(f'Unknown method: {method}')
94102

0 commit comments

Comments
 (0)